Skip to content

Commit

Permalink
better implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
aviatesk committed Jul 24, 2024
1 parent 876e5c7 commit 0b253e9
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 75 deletions.
181 changes: 120 additions & 61 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# This file is a part of Julia. License is MIT: https://julialang.org/license

struct SlotRefinement
slot::SlotNumber
typ::Any
SlotRefinement(slot::SlotNumber, @nospecialize(typ)) = new(slot, typ)
end

# See if the inference result of the current statement's result value might affect
# the final answer for the method (aside from optimization potential and exceptions).
# To do that, we need to check both for slot assignment and SSA usage.
Expand Down Expand Up @@ -39,16 +45,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
multiple_matches = napplicable > 1
fargs = arginfo.fargs
all_effects = EFFECTS_TOTAL
if fargs !== nothing
# keeps refinement information on slot types obtained from call signature
refine_targets = Union{Nothing,StmtChange}[]
for i = 1:length(fargs)
x = fargs[i]
push!(refine_targets, isa(x, SlotNumber) ? StmtChange(x, Bottom) : nothing)
end
else
refine_targets = nothing
end
slotrefinements = nothing # keeps refinement information on slot types obtained from call signature

for i in 1:napplicable
match = applicable[i]::MethodMatch
Expand Down Expand Up @@ -172,13 +169,8 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
conditionals[2][i] = tmerge(𝕃ᵢ, conditionals[2][i], cnd.elsetype)
end
end
if refine_targets !== nothing
for i in 1:length(refine_targets)
target = refine_targets[i]
if target !== nothing
refine_targets[i] = StmtChange(target.slot, tmerge(𝕃ᵢ, fieldtype(sig, i), target.typ))
end
end
if sv isa InferenceState && fargs !== nothing
slotrefinements = collect_slot_refinements!(𝕃ᵢ, slotrefinements, sig, argtypes, fargs, sv)
end
if bail_out_call(interp, InferenceLoopState(sig, rettype, all_effects), sv)
add_remark!(interp, sv, "Call inference reached maximally imprecise information. Bailing on.")
Expand All @@ -195,7 +187,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# there is unanalyzed candidate, widen type and effects to the top
rettype = excttype = Any
all_effects = Effects()
refine_targets = nothing
slotrefinements = nothing
elseif isa(matches, MethodMatches) ? (!matches.fullmatch || any_ambig(matches)) :
(!all(matches.fullmatches) || any_ambig(matches))
# Account for the fact that we may encounter a MethodError with a non-covered or ambiguous signature.
Expand All @@ -222,18 +214,9 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end
end

# if refinement information on slot types is available, apply it now
anyrefined = false
if rettype !== Bottom && refine_targets !== nothing
for target in refine_targets
if target !== nothing
if target.typ !== Bottom
push!(sv.curr_stmt_changes, target)
anyrefined = true # TODO limit this when t ⋤ old
end
end
end
end
anyrefined = (sv isa InferenceState && slotrefinements !== nothing &&
check_slot_refinements!(𝕃ᵢ, slotrefinements, sv))
anyrefined || (slotrefinements = nothing)
if call_result_unused(si) && !(rettype === Bottom)
add_remark!(interp, sv, "Call result type was widened because the return value is unused")
# We're mainly only here because the optimizer might want this code,
Expand All @@ -258,7 +241,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end
end

return CallMeta(rettype, excttype, all_effects, info)
return CallMeta(rettype, excttype, all_effects, info, slotrefinements)
end

struct FailedMethodMatch
Expand Down Expand Up @@ -523,6 +506,59 @@ function conditional_argtype(𝕃ᵢ::AbstractLattice, @nospecialize(rt), @nospe
end
end

function collect_slot_refinements!(𝕃ᵢ::AbstractLattice, slotrefinements::Union{Nothing,Vector{Any}},
@nospecialize(sig), argtypes::Vector{Any}, fargs::Vector{Any}, sv::InferenceState)
= strictpartialorder(𝕃ᵢ)
for i = 1:length(fargs)
fargᵢ = fargs[i]
if fargᵢ isa SlotNumber
fidx = slot_id(fargᵢ)
argt = widenslotwrapper(argtypes[i])
if isvarargtype(argt)
@assert fieldcount(sig) == i
argt = unwrapva(argt)
end
sigt = fieldtype(sig, i)
newtyp = tmeet(𝕃ᵢ, argt, sigt)
oldtyp = sv.currstate[fidx].typ
if newtyp oldtyp
if slotrefinements === nothing
slotrefinements = fill!(Vector{Any}(undef, length(sv.slottypes)), true)
end
oldnewtyp = slotrefinements[fidx]
if oldnewtyp === true
slotrefinements[fidx] = newtyp
elseif oldnewtyp === false
# can't refine this slot anymore
else
slotrefinements[fidx] = tmerge(𝕃ᵢ, newtyp, oldnewtyp)
end
elseif slotrefinements !== nothing
slotrefinements[fidx] = false
end
end
end
return slotrefinements
end

# check if type refinement information on local slots from matching method signatures
# is really worthwhile to propagate
function check_slot_refinements!(𝕃ᵢ::AbstractLattice, slotrefinements::Vector{Any}, sv::InferenceState)
anyrefined = false
= strictpartialorder(𝕃ᵢ)
for i = 1:length(slotrefinements)
newtyp = slotrefinements[i]
newtyp isa Bool && continue
oldtyp = sv.currstate[i].typ
if newtyp oldtyp
anyrefined |= true
else
slotrefinements[i] = false
end
end
return anyrefined
end

function add_call_backedges!(interp::AbstractInterpreter, @nospecialize(rettype),
all_effects::Effects, anyrefined::Bool, edges::Vector{MethodInstance},
matches::Union{MethodMatches,UnionSplitMethodMatches}, @nospecialize(atype),
Expand Down Expand Up @@ -1855,12 +1891,7 @@ function abstract_call_builtin(interp::AbstractInterpreter, f::Builtin, (; fargs
ft = popfirst!(argtypes)
rt = builtin_tfunction(interp, f, argtypes, sv)
pushfirst!(argtypes, ft)
if f === typeassert
# perform very limited back-propagation of invariants after this type assertion
if rt !== Bottom && isa(fargs, Vector{Any}) && (x2 = fargs[2]; isa(x2, SlotNumber))
push!(sv.curr_stmt_changes, StmtChange(x2, rt))
end
elseif has_mustalias(𝕃ᵢ) && f === getfield && isa(fargs, Vector{Any}) && la 3
if has_mustalias(𝕃ᵢ) && f === getfield && isa(fargs, Vector{Any}) && la 3
a3 = argtypes[3]
if isa(a3, Const)
if rt !== Bottom && !isalreadyconst(rt)
Expand Down Expand Up @@ -2161,7 +2192,17 @@ function abstract_call_known(interp::AbstractInterpreter, @nospecialize(f),
exct = builtin_exct(𝕃ᵢ, f, argtypes, rt)
end
pushfirst!(argtypes, ft)
return CallMeta(rt, exct, effects, NoCallInfo())
refinements = nothing
if sv isa InferenceState && f === typeassert
# perform very limited back-propagation of invariants after this type assertion
if rt !== Bottom && isa(fargs, Vector{Any})
farg2 = fargs[2]
if farg2 isa SlotNumber
refinements = SlotRefinement(farg2, rt)
end
end
end
return CallMeta(rt, exct, effects, NoCallInfo(), refinements)
elseif isa(f, Core.OpaqueClosure)
# calling an OpaqueClosure about which we have no information returns no information
return CallMeta(typeof(f).parameters[2], Any, Effects(), NoCallInfo())
Expand Down Expand Up @@ -2451,10 +2492,14 @@ function collect_argtypes(interp::AbstractInterpreter, ea::Vector{Any}, vtypes::
end

struct RTEffects
rt
exct
rt::Any
exct::Any
effects::Effects
RTEffects(@nospecialize(rt), @nospecialize(exct), effects::Effects) = new(rt, exct, effects)
refinements # ::Union{Nothing,SlotRefinement,Vector{Any}}
function RTEffects(rt, exct, effects::Effects, refinements=nothing)
@nospecialize rt exct refinements
return new(rt, exct, effects, refinements)
end
end

function abstract_call(interp::AbstractInterpreter, arginfo::ArgInfo, sv::InferenceState)
Expand All @@ -2476,8 +2521,8 @@ function abstract_eval_call(interp::AbstractInterpreter, e::Expr, vtypes::Union{
return RTEffects(Bottom, Any, Effects())
end
arginfo = ArgInfo(ea, argtypes)
(; rt, exct, effects) = abstract_call(interp, arginfo, sv)
return RTEffects(rt, exct, effects)
(; rt, exct, effects, refinements) = abstract_call(interp, arginfo, sv)
return RTEffects(rt, exct, effects, refinements)
end

function abstract_eval_new(interp::AbstractInterpreter, e::Expr, vtypes::Union{VarTable,Nothing},
Expand Down Expand Up @@ -2819,9 +2864,9 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
rt = old_rt === NOT_FOUND ? rt : tmerge(typeinf_lattice(interp), old_rt, rt)
return RTEffects(rt, Union{}, EFFECTS_TOTAL)
end
(; rt, exct, effects) = abstract_eval_special_value(interp, e, vtypes, sv)
(; rt, exct, effects, refinements) = abstract_eval_special_value(interp, e, vtypes, sv)
else
(; rt, exct, effects) = abstract_eval_statement_expr(interp, e, vtypes, sv)
(; rt, exct, effects, refinements) = abstract_eval_statement_expr(interp, e, vtypes, sv)
if effects.noub === NOUB_IF_NOINBOUNDS
if has_curr_ssaflag(sv, IR_FLAG_INBOUNDS)
effects = Effects(effects; noub=ALWAYS_FALSE)
Expand Down Expand Up @@ -2851,7 +2896,7 @@ function abstract_eval_statement(interp::AbstractInterpreter, @nospecialize(e),
set_curr_ssaflag!(sv, flags_for_effects(effects), IR_FLAGS_EFFECTS)
merge_effects!(interp, sv, effects)

return RTEffects(rt, exct, effects)
return RTEffects(rt, exct, effects, refinements)
end

function override_effects(effects::Effects, override::EffectsOverride)
Expand Down Expand Up @@ -3100,7 +3145,12 @@ struct BasicStmtChange
rt::Any # extended lattice element or `nothing` - `nothing` if this statement may not be used as an SSA Value
exct::Any
# TODO effects::Effects
BasicStmtChange(changes::Union{Nothing,StateUpdate}, @nospecialize(rt), @nospecialize(exct)) = new(changes, rt, exct)
refinements # ::Union{Nothing,SlotRefinement,Vector{Any}}
function BasicStmtChange(changes::Union{Nothing,StateUpdate}, rt::Any, exct::Any,
refinements=nothing)
@nospecialize rt exct refinements
return new(changes, rt, exct, refinements)
end
end

@inline function abstract_eval_basic_statement(interp::AbstractInterpreter,
Expand All @@ -3115,9 +3165,9 @@ end
changes = nothing
hd = stmt.head
if hd === :(=)
(; rt, exct) = abstract_eval_statement(interp, stmt.args[2], pc_vartable, frame)
(; rt, exct, refinements) = abstract_eval_statement(interp, stmt.args[2], pc_vartable, frame)
if rt === Bottom
return BasicStmtChange(nothing, Bottom, exct)
return BasicStmtChange(nothing, Bottom, exct, refinements)
end
lhs = stmt.args[1]
if isa(lhs, SlotNumber)
Expand All @@ -3127,7 +3177,7 @@ end
elseif !isa(lhs, SSAValue)
merge_effects!(interp, frame, EFFECTS_UNKNOWN)
end
return BasicStmtChange(changes, rt, exct)
return BasicStmtChange(changes, rt, exct, refinements)
elseif hd === :method
fname = stmt.args[1]
if isa(fname, SlotNumber)
Expand All @@ -3139,8 +3189,8 @@ end
is_meta_expr(stmt)))
return BasicStmtChange(nothing, Nothing, Bottom)
else
(; rt, exct) = abstract_eval_statement(interp, stmt, pc_vartable, frame)
return BasicStmtChange(nothing, rt, exct)
(; rt, exct, refinements) = abstract_eval_statement(interp, stmt, pc_vartable, frame)
return BasicStmtChange(nothing, rt, exct, refinements)
end
end

Expand Down Expand Up @@ -3258,7 +3308,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
end

states = frame.bb_vartables
currstate = copy(states[currbb]::VarTable)
frame.currstate = currstate = copy(states[currbb]::VarTable)
while currbb <= nbbs
delete!(W, currbb)
bbstart = first(bbs[currbb].stmts)
Expand Down Expand Up @@ -3402,7 +3452,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
# Fall through terminator - treat as regular stmt
end
# Process non control-flow statements
(; changes, rt, exct) = abstract_eval_basic_statement(interp,
(; changes, rt, exct, refinements) = abstract_eval_basic_statement(interp,
stmt, currstate, frame)
if !has_curr_ssaflag(frame, IR_FLAG_NOTHROW)
if exct !== Union{}
Expand All @@ -3423,17 +3473,26 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
if changes !== nothing
stoverwrite1!(currstate, changes)
end
while !isempty(frame.curr_stmt_changes)
stmtchange = pop!(frame.curr_stmt_changes)
if changes !== nothing && stmtchange.slot == changes.var
continue # type propagation from statement (like assignment) should have the precedence
function apply_refinement!(slot::SlotNumber, @nospecialize(newtyp))
if changes !== nothing && slot == changes.var
return # type propagation from statement (like assignment) should have the precedence
end
vtype = currstate[slot_id(stmtchange.slot)]
if (𝕃ᵢ, stmtchange.typ, vtype.typ)
stmtupdate = StateUpdate(stmtchange.slot, VarState(stmtchange.typ, vtype.undef), false)
vtype = currstate[slot_id(slot)]
oldtyp = vtype.typ
if (𝕃ᵢ, newtyp, oldtyp)
stmtupdate = StateUpdate(slot, VarState(newtyp, vtype.undef), false)
stoverwrite1!(currstate, stmtupdate)
end
end
if refinements isa SlotRefinement
apply_refinement!(refinements.slot, refinements.typ)
elseif refinements isa Vector{Any}
for i = 1:length(refinements)
slotrefinement = refinements[i]
slotrefinement isa Bool && continue
apply_refinement!(SlotNumber(i), slotrefinement)
end
end
if rt === nothing
ssavaluetypes[currpc] = Any
continue
Expand Down
15 changes: 3 additions & 12 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ Note that `InferenceState` will maintain multiple `VarTable`s at each SSA statem
to enable flow-sensitive analysis.
"""
const VarTable = Vector{VarState}
const initial_state = VarState[]

const CACHE_MODE_NULL = 0x00 # not cached, without optimization
const CACHE_MODE_GLOBAL = 0x01 << 0 # cached globally, optimization allowed
Expand All @@ -227,12 +228,6 @@ struct HandlerInfo
handler_at::Vector{Tuple{Int,Int}} # tuple of current (handler, exception stack) value at the pc
end

struct StmtChange
slot::SlotNumber
typ
StmtChange(slot::SlotNumber, @nospecialize(typ)) = new(slot, typ)
end

mutable struct InferenceState
#= information about this method instance =#
linfo::MethodInstance
Expand All @@ -255,9 +250,7 @@ mutable struct InferenceState
ssavaluetypes::Vector{Any}
stmt_edges::Vector{Vector{Any}}
stmt_info::Vector{CallInfo}
# additional state updates at current statement made by means other than the assignment
# e.g. type information refinement from `typeassert` call itself
curr_stmt_changes::Vector{StmtChange}
currstate::VarTable

#= intermediate states for interprocedural abstract interpretation =#
pclimitations::IdSet{InferenceState} # causes of precision restrictions (LimitedAccuracy) on currpc ssavalue
Expand Down Expand Up @@ -310,8 +303,6 @@ mutable struct InferenceState
stmt_edges = Vector{Vector{Any}}(undef, nstmts)
stmt_info = CallInfo[ NoCallInfo() for i = 1:nstmts ]

curr_stmt_changes = StmtChange[]

nslots = length(src.slotflags)
slottypes = Vector{Any}(undef, nslots)
bb_vartables = Union{Nothing,VarTable}[ nothing for i = 1:length(cfg.blocks) ]
Expand Down Expand Up @@ -361,7 +352,7 @@ mutable struct InferenceState

this = new(
mi, world, mod, sptypes, slottypes, src, cfg, method_info,
currbb, currpc, ip, handler_info, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info, curr_stmt_changes,
currbb, currpc, ip, handler_info, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info, initial_state,
pclimitations, limitations, cycle_backedges, callers_in_cycle, dont_work_on_me, parent,
result, unreachable, valid_worlds, bestguess, exc_bestguess, ipo_effects,
restrict_abstract_call_sites, cache_mode, insert_coverage,
Expand Down
6 changes: 6 additions & 0 deletions base/compiler/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@ struct CallMeta
exct::Any
effects::Effects
info::CallInfo
refinements # ::Union{Nothing,SlotRefinement,Vector{Any}}
function CallMeta(rt::Any, exct::Any, effects::Effects, info::CallInfo,
refinements=nothing)
@nospecialize rt exct info
return new(rt, exct, effects, info, refinements)
end
end

struct NoCallInfo <: CallInfo end
Expand Down
Loading

0 comments on commit 0b253e9

Please sign in to comment.