From e87ad2b13f98151528faf5ab9983f9fbf7b92a3f Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki Date: Thu, 20 May 2021 01:47:30 +0900 Subject: [PATCH] inference: backward constraint propagation from call signatures This PR implements another (limited) backward analysis pass in abstract interpretation; it exploits the signatures of matching methods and refines the types of slots. Here are a couple of examples where these changes will improve the accuracy: > generic function example ```julia addi(a::Integer, b::Integer) = a + b Base.return_types((Any,Any,)) do a, b c = addi(a, b) return a, b, c # now the compiler understands `a::Integer`, `b::Integer` end ``` > `typeassert` example ```julia Base.return_types((Any,)) do a typeassert(a, Int) return a # now the compiler understands `a::Int` end ``` This PR consists of two main parts: 1.) obtain refinement information and back-propagate it, and 2.) apply state updates. As for 1., unlike conditional constraints, refinement information isn't encoded within lattice elements; rather, it is conveyed by the newly defined field `InferenceState.state_updates`, which is refreshed on each program counter increment. For now, refinement information is obtained from call signatures of generic functions and `typeassert`. Finally, in order to apply multiple state updates, this PR extends `StateUpdate` and `stupdate!` so that they can hold and apply multiple state updates. 
--- base/compiler/abstractinterpretation.jl | 95 ++++++++++++++++-- base/compiler/inferencestate.jl | 3 +- base/compiler/typelattice.jl | 43 +++++--- test/compiler/inference.jl | 125 ++++++++++++++++++++++++ 4 files changed, 241 insertions(+), 25 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 0f2e2b78d6be51..fad2e138f925ab 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -110,6 +110,13 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), any_const_result = false const_results = Union{InferenceResult,Nothing}[] multiple_matches = napplicable > 1 + refine_targets = nothing # keeps refinement information on slot types obtained from call signature + if fargs !== nothing + refine_targets = Union{Nothing,Tuple{SlotNumber,Any}}[] + for x in fargs + push!(refine_targets, isa(x, SlotNumber) ? (x, Bottom) : nothing) + end + end if f !== nothing && napplicable == 1 && is_method_pure(applicable[1]::MethodMatch) val = pure_eval_call(f, argtypes) @@ -197,6 +204,15 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), conditionals[2][i] = tmerge(conditionals[2][i], elsetype) end end + if refine_targets !== nothing + for i in 1:length(refine_targets) + target = refine_targets[i] + if target !== nothing + slot, t = target + refine_targets[i] = (slot, tmerge(fieldtype(sig, i), t)) + end + end + end if bail_out_call(interp, rettype, sv) break end @@ -209,6 +225,12 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), info = ConstCallInfo(info, const_results) end + # refinement information from call signatures is valid only after we succeed in inferring + # all the matching signatures and we should invalidate it if we bailed out early + if seen ≠ napplicable + refine_targets = nothing + end + if rettype isa LimitedAccuracy union!(sv.pclimitations, rettype.causes) rettype = 
rettype.typ @@ -263,6 +285,18 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), end @assert !(rettype isa InterConditional) "invalid lattice element returned from inter-procedural context" + # if refinement information on slot types is available, apply it now + anyrefined = false + if rettype !== Bottom && refine_targets !== nothing + for target in refine_targets + if target !== nothing + slot, t = target + if t !== Bottom + anyrefined |= add_state_update!(slot, t, sv) + end + end + end + end if call_result_unused(sv) && !(rettype === Bottom) add_remark!(interp, sv, "Call result type was widened because the return value is unused") # We're mainly only here because the optimizer might want this code, @@ -273,7 +307,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), # and avoid keeping track of a more complex result type. rettype = Any end - add_call_backedges!(interp, rettype, edges, fullmatch, mts, atype, sv) + add_call_backedges!(interp, anyrefined, rettype, edges, fullmatch, mts, atype, sv) if !isempty(sv.pclimitations) # remove self, if present delete!(sv.pclimitations, sv) for caller in sv.callers_in_cycle @@ -285,13 +319,13 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f), end function add_call_backedges!(interp::AbstractInterpreter, - @nospecialize(rettype), + anyrefined::Bool, @nospecialize(rettype), edges::Vector{MethodInstance}, fullmatch::Vector{Bool}, mts::Vector{Core.MethodTable}, @nospecialize(atype), sv::InferenceState) - if rettype === Any - # for `NativeInterpreter`, we don't add backedges when a new method couldn't refine - # (widen) this type + if !anyrefined && rettype === Any + # for `NativeInterpreter`, we don't add backedges when we've not used refinement + # information from call signature and a new method couldn't refine (widen) this type return end for edge in edges @@ -1000,6 +1034,11 @@ function 
abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, fargs::U if 1 <= idx <= length(cti) rt = unwrapva(cti[idx]) end + elseif f === typeassert + # perform very limited back-propagation of invariants after this type asertion + if rt !== Bottom && isa(fargs, Vector{Any}) && (x2 = fargs[2]; isa(x2, SlotNumber)) + add_state_update!(x2, rt, sv) + end elseif (rt === Bool || (isa(rt, Const) && isa(rt.val, Bool))) && isa(fargs, Vector{Any}) # perform very limited back-propagation of type information for `is` and `isa` if f === isa @@ -1658,6 +1697,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) stmt = frame.src.code[pc] changes = states[pc]::VarTable t = nothing + empty!(frame.state_updates) hd = isa(stmt, Expr) ? stmt.head : nothing @@ -1778,12 +1818,12 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) frame.src.ssavaluetypes[pc] = t lhs = stmt.args[1] if isa(lhs, SlotNumber) - changes = StateUpdate(lhs, VarState(t, false), changes, false) + changes = StateUpdate([lhs], [VarState(t, false)], changes, false) end elseif hd === :method fname = stmt.args[1] if isa(fname, SlotNumber) - changes = StateUpdate(fname, VarState(Any, false), changes, false) + changes = StateUpdate([fname], [VarState(Any, false)], changes, false) end elseif hd === :inbounds || hd === :meta || hd === :loopinfo || hd === :code_coverage_effect # these do not generate code @@ -1821,6 +1861,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) pc´ > n && break # can't proceed with the fast-path fall-through frame.handler_at[pc´] = frame.cur_hand + changes = collect_state_updates!(changes, frame) newstate = stupdate!(states[pc´], changes) if isa(stmt, GotoNode) && frame.pc´´ < pc´ # if we are processing a goto node anyways, @@ -1846,6 +1887,44 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState) nothing end +function add_state_update!(slot::SlotNumber, @nospecialize(new), 
frame::InferenceState) + states = frame.stmt_types[frame.currpc]::VarTable + old = ((states[slot_id(slot)])::VarState).typ + if !(old ⊑ new) # new ⋤ old + push!(frame.state_updates, (slot, new)) + return true + end + return false +end + +function collect_state_updates!(changes::StateUpdate, frame::InferenceState) + state_updates = frame.state_updates + vars = changes.vars + vtypes = changes.vtypes + while !isempty(state_updates) + var, typ = pop!(state_updates) + var in vars && continue # state update from lhs assigment should always has the precedence + push!(vars, var) + vtype = VarState(typ, (changes.state[slot_id(var)]::VarState).undef) + push!(vtypes, vtype) + end + return changes +end + +function collect_state_updates!(changes::VarTable, frame::InferenceState) + state_updates = frame.state_updates + isempty(state_updates) && return changes + vars = SlotNumber[] + vtypes = VarState[] + while !isempty(state_updates) + var, typ = pop!(state_updates) + push!(vars, var) + vtype = VarState(typ, (changes[slot_id(var)]::VarState).undef) + push!(vtypes, vtype) + end + return StateUpdate(vars, vtypes, changes, false) +end + function conditional_changes(changes::VarTable, @nospecialize(typ), var::SlotNumber) oldtyp = (changes[slot_id(var)]::VarState).typ # approximate test for `typ ∩ oldtyp` being better than `oldtyp` @@ -1853,7 +1932,7 @@ function conditional_changes(changes::VarTable, @nospecialize(typ), var::SlotNum if ignorelimited(typ) ⊑ ignorelimited(oldtyp) # typ is better unlimited, but we may still need to compute the tmeet with the limit "causes" since we ignored those in the comparison oldtyp isa LimitedAccuracy && (typ = tmerge(typ, LimitedAccuracy(Bottom, oldtyp.causes))) - return StateUpdate(var, VarState(typ, false), changes, true) + return StateUpdate([var], [VarState(typ, false)], changes, true) end return changes end diff --git a/base/compiler/inferencestate.jl b/base/compiler/inferencestate.jl index cb5d2009a9171e..bc80e60cc50553 100644 --- 
a/base/compiler/inferencestate.jl +++ b/base/compiler/inferencestate.jl @@ -21,6 +21,7 @@ mutable struct InferenceState stmt_types::Vector{Union{Nothing, Vector{Any}}} # ::Vector{Union{Nothing, VarTable}} stmt_edges::Vector{Union{Nothing, Vector{Any}}} stmt_info::Vector{Any} + state_updates::Vector{Tuple{SlotNumber,Any}} # additional state update obtained at currpc # return type bestguess #::Type # current active instruction pointers @@ -108,7 +109,7 @@ mutable struct InferenceState sp, slottypes, inmodule, 0, IdSet{InferenceState}(), IdSet{InferenceState}(), src, get_world_counter(interp), valid_worlds, - nargs, s_types, s_edges, stmt_info, + nargs, s_types, s_edges, stmt_info, Tuple{SlotNumber,Any}[], Union{}, W, 1, n, cur_hand, handler_at, n_handlers, ssavalue_uses, throw_blocks, diff --git a/base/compiler/typelattice.jl b/base/compiler/typelattice.jl index 6391d4029b58e7..795288a6fd2bae 100644 --- a/base/compiler/typelattice.jl +++ b/base/compiler/typelattice.jl @@ -84,8 +84,8 @@ end const VarTable = Array{Any,1} struct StateUpdate - var::SlotNumber - vtype::VarState + vars::Vector{SlotNumber} + vtypes::Vector{VarState} state::VarTable conditional::Bool end @@ -320,32 +320,40 @@ ignorelimited(@nospecialize typ) = typ ignorelimited(typ::LimitedAccuracy) = typ.typ function stupdate!(state::Nothing, changes::StateUpdate) - newst = copy(changes.state) - changeid = slot_id(changes.var) - newst[changeid] = changes.vtype + newstate = copy(changes.state) + changeids = Int[] + for (var, vtype) in zip(changes.vars, changes.vtypes) + changeid = slot_id(var) + newstate[changeid] = vtype + push!(changeids, changeid) + end # remove any Conditional for this slot from the vtable # (unless this change is came from the conditional) if !changes.conditional - for i = 1:length(newst) - newtype = newst[i] + for i = 1:length(newstate) + newtype = newstate[i] if isa(newtype, VarState) newtypetyp = ignorelimited(newtype.typ) - if isa(newtypetyp, Conditional) && slot_id(newtypetyp.var) 
== changeid + if isa(newtypetyp, Conditional) && slot_id(newtypetyp.var) in changeids newtypetyp = widenwrappedconditional(newtype.typ) - newst[i] = VarState(newtypetyp, newtype.undef) + newstate[i] = VarState(newtypetyp, newtype.undef) end end end end - return newst + return newstate end function stupdate!(state::VarTable, changes::StateUpdate) + changeids = Int[] + for var in changes.vars + push!(changeids, slot_id(var)) + end newstate = nothing - changeid = slot_id(changes.var) for i = 1:length(state) - if i == changeid - newtype = changes.vtype + j = findfirst(==(i), changeids) + if j !== nothing + newtype = changes.vtypes[j] else newtype = changes.state[i] end @@ -354,7 +362,7 @@ function stupdate!(state::VarTable, changes::StateUpdate) # (unless this change is came from the conditional) if !changes.conditional && isa(newtype, VarState) newtypetyp = ignorelimited(newtype.typ) - if isa(newtypetyp, Conditional) && slot_id(newtypetyp.var) == changeid + if isa(newtypetyp, Conditional) && slot_id(newtypetyp.var) in changeids newtypetyp = widenwrappedconditional(newtype.typ) newtype = VarState(newtypetyp, newtype.undef) end @@ -385,7 +393,10 @@ stupdate!(state::Nothing, changes::VarTable) = copy(changes) stupdate!(state::Nothing, changes::Nothing) = nothing function stupdate1!(state::VarTable, change::StateUpdate) - changeid = slot_id(change.var) + vars, vtypes = change.vars, change.vtypes + @assert length(vars) == length(vtypes) == 1 + var, vtype = vars[1], vtypes[1] + changeid = slot_id(var) # remove any Conditional for this slot from the catch block vtable # (unless this change is came from the conditional) if !change.conditional @@ -404,7 +415,7 @@ function stupdate1!(state::VarTable, change::StateUpdate) end end # and update the type of it - newtype = change.vtype + newtype = vtype oldtype = state[changeid] if schanged(newtype, oldtype) state[changeid] = smerge(oldtype, newtype) diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 
96ab411c654a0a..7335c44ff149b2 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -3288,3 +3288,128 @@ end == [Union{Some{Float64}, Some{Int}, Some{UInt8}}] true end end + +@testset "constraint back-propagation from typeassert" begin + @test Base.return_types((Any,)) do a + typeassert(a, Int) + return a + end == Any[Int] + + @test Base.return_types((Any,Bool)) do a, b + if b + typeassert(a, Int64) + else + typeassert(a, Int32) + end + return a + end == Any[Union{Int32,Int64}] +end + +@testset "constraint back-propagation from call signature" begin + # basic case + @test (@eval Module() begin + f(::Int) = return + Base.return_types((Any,)) do a + f(a) + return a + end + end) == Any[Int] + + # union-split case + @test (@eval Module() begin + f(::Int32) = return + f(::Int64) = return + Base.return_types((Any,)) do a + f(a) + return a + end + end) == Any[Union{Int32,Int64}] + + # multiple state updates + @test (@eval Module() begin + f(::Int) = return + g(::Nothing) = return + Base.return_types((Any,Any)) do a, b + f(a); g(b) + return a, b + end + end) == Any[Tuple{Int,Nothing}] + + # refinement should happen only when it's worthwhile + @test (@eval Module() begin + f(::Any) = return + Base.return_types((Integer,)) do a + f(a) + return a + end + end) == Any[Integer] + + # state update on lhs slot (assignment effect should have the precedence) + @test (@eval Module() begin + f(::Int) = return + Base.return_types((Any,)) do a + a = f(a) + return a + end + end) == Any[Nothing] + + # make sure to invalidate an intermediate refinement information when we bail out early + @test (@eval Module() begin + f(::Val{0}) = return 0 + f(::Val{1}) = return undefvar # ::Any + f(::Val{2}) = return 2 + Base.return_types((Any,)) do a + f(a) + return a + end + end) == Any[Any] # shouldn't be `Any[Union{Val{0},Val{2}}]` or something + + # if we see all the matching methods, we don't need to throw away refinement information + # even if it's caught by `bail_out_call` 
check + if length(methods(+, (Integer, Integer))) > Core.Compiler.InferenceParams().MAX_METHODS + @test (@eval Module() begin + addn(a::Integer, b::Integer) = a + b # too many maching methods, and return type should be annotated as `Any` (and thus caught by `bail_out_call`) + Base.return_types((Any,Any)) do a, b + addn(a, b) + return a, b # ::Tuple{Integer,Integer} + end + end) == Any[Tuple{Integer,Integer}] + end + + # make sure to add backedges when we use call signature constraint + let + m = Module() + @eval m outer(a) = (_inner!(a); return a) + + @test (@eval m begin + _inner!(::Int) = globalvar # ::Any + Base.return_types((Any,)) do a + return outer(a) # ::Int + end + end) == Any[Int] + + # new definition of `_inner!` should invalidate `outer` + # (even if the previous return type is annotated as `Any`) + @test (@eval m begin + _inner!(::Nothing) = globalvar # ::Any + Base.return_types((Any,)) do a + # since inference will bail out at the first matched `_inner!` and so call signature constraint won't be available + return outer(a) # ::Union{Int,Nothing} ideally, but ::Any + end + end) ≠ Any[Int] + end + + # https://github.com/JuliaLang/julia/issues/37866 + @test (@eval Module() begin + function find_first_above_5(v::Vector{Union{Nothing,Float64}}) + for x in v + if x > 5.0 + return x # x > 5.0 is MethodError for Nothing so can assume ::Float64 + end + end + return 0.0 + end + + Base.return_types(find_first_above_5, (Vector{Union{Nothing,Float64}},)) + end) == Any[Float64] +end