Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add edges vector to CodeInstance/CodeInfo to keep backedges as edges #54894

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
1 change: 0 additions & 1 deletion base/Base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,6 @@ function Core._hasmethod(@nospecialize(f), @nospecialize(t)) # this function has
return Core._hasmethod(tt)
end


# core operations & types
include("promotion.jl")
include("tuple.jl")
Expand Down
6 changes: 3 additions & 3 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -511,11 +511,11 @@ function CodeInstance(
mi::MethodInstance, owner, @nospecialize(rettype), @nospecialize(exctype), @nospecialize(inferred_const),
@nospecialize(inferred), const_flags::Int32, min_world::UInt, max_world::UInt,
effects::UInt32, @nospecialize(analysis_results),
relocatability::UInt8, edges::Union{DebugInfo,Nothing})
relocatability::UInt8, di::Union{DebugInfo,Nothing}, edges::SimpleVector)
return ccall(:jl_new_codeinst, Ref{CodeInstance},
(Any, Any, Any, Any, Any, Any, Int32, UInt, UInt, UInt32, Any, UInt8, Any),
(Any, Any, Any, Any, Any, Any, Int32, UInt, UInt, UInt32, Any, UInt8, Any, Any),
mi, owner, rettype, exctype, inferred_const, inferred, const_flags, min_world, max_world,
effects, analysis_results, relocatability, edges)
effects, analysis_results, relocatability, di, edges)
end
GlobalRef(m::Module, s::Symbol) = ccall(:jl_module_globalref, Ref{GlobalRef}, (Any, Any), m, s)
Module(name::Symbol=:anonymous, std_imports::Bool=true, default_names::Bool=true) = ccall(:jl_f_new_module, Ref{Module}, (Any, Bool, Bool), name, std_imports, default_names)
Expand Down
101 changes: 26 additions & 75 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
gfresult = Future{CallMeta}()
# intermediate work for computing gfresult
rettype = exctype = Bottom
edges = MethodInstance[]
conditionals = nothing # keeps refinement information of call argument types when the return type is boolean
seenall = true
const_results = nothing # or const_results::Vector{Union{Nothing,ConstResult}} if any const results are available
Expand Down Expand Up @@ -95,7 +94,7 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
#end
mresult = abstract_call_method(interp, method, sig, match.sparams, multiple_matches, si, sv)::Future
function handle1(interp, sv)
local (; rt, exct, edge, effects, volatile_inf_result) = mresult[]
local (; rt, exct, effects, volatile_inf_result) = mresult[]
this_conditional = ignorelimited(rt)
this_rt = widenwrappedconditional(rt)
this_exct = exct
Expand All @@ -119,17 +118,17 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# e.g. in cases when there are cycles but cached result is still accurate
this_conditional = this_const_conditional
this_rt = this_const_rt
(; effects, const_result, edge) = const_call_result
(; effects, const_result) = const_call_result
elseif is_better_effects(const_call_result.effects, effects)
(; effects, const_result, edge) = const_call_result
(; effects, const_result) = const_call_result
else
add_remark!(interp, sv, "[constprop] Discarded because the result was wider than inference")
end
# Treat the exception type separately. Currently, constprop often cannot determine the exception type
# because consistent-cy does not apply to exceptions.
if const_call_result.exct ⋤ this_exct
this_exct = const_call_result.exct
(; const_result, edge) = const_call_result
(; const_result) = const_call_result
else
add_remark!(interp, sv, "[constprop] Discarded exception type because result was wider than inference")
end
Expand All @@ -142,7 +141,6 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
end
const_results[i] = const_result
end
edge === nothing || push!(edges, edge)
@assert !(this_conditional isa Conditional || this_rt isa MustAlias) "invalid lattice element returned from inter-procedural context"
if can_propagate_conditional(this_conditional, argtypes)
# The only case where we need to keep this in rt is where
Expand Down Expand Up @@ -230,8 +228,6 @@ function abstract_call_gf_by_type(interp::AbstractInterpreter, @nospecialize(f),
# and avoid keeping track of a more complex result type.
rettype = Any
end
any_slot_refined = slotrefinements !== nothing
add_call_backedges!(interp, rettype, all_effects, any_slot_refined, edges, matches, atype.contents, sv)
if isa(sv, InferenceState)
# TODO (#48913) implement a proper recursion handling for irinterp:
# This works just because currently the `:terminate` condition guarantees that
Expand Down Expand Up @@ -267,12 +263,6 @@ any_ambig(info::MethodMatchInfo) = any_ambig(info.results)
any_ambig(m::MethodMatches) = any_ambig(m.info)
fully_covering(info::MethodMatchInfo) = info.fullmatch
fully_covering(m::MethodMatches) = fully_covering(m.info)
function add_uncovered_edges!(sv::AbsIntState, info::MethodMatchInfo, @nospecialize(atype))
fully_covering(info) || add_mt_backedge!(sv, info.mt, atype)
nothing
end
add_uncovered_edges!(sv::AbsIntState, matches::MethodMatches, @nospecialize(atype)) =
add_uncovered_edges!(sv, matches.info, atype)

struct UnionSplitMethodMatches
applicable::Vector{Any}
Expand All @@ -284,23 +274,14 @@ any_ambig(info::UnionSplitInfo) = any(any_ambig, info.split)
any_ambig(m::UnionSplitMethodMatches) = any_ambig(m.info)
fully_covering(info::UnionSplitInfo) = all(fully_covering, info.split)
fully_covering(m::UnionSplitMethodMatches) = fully_covering(m.info)
function add_uncovered_edges!(sv::AbsIntState, info::UnionSplitInfo, @nospecialize(atype))
all(fully_covering, info.split) && return nothing
# add mt backedges with removing duplications
for mt in uncovered_method_tables(info)
add_mt_backedge!(sv, mt, atype)
end
end
add_uncovered_edges!(sv::AbsIntState, matches::UnionSplitMethodMatches, @nospecialize(atype)) =
add_uncovered_edges!(sv, matches.info, atype)
function uncovered_method_tables(info::UnionSplitInfo)
mts = MethodTable[]

nmatches(info::MethodMatchInfo) = length(info.results)
function nmatches(info::UnionSplitInfo)
n = 0
for mminfo in info.split
fully_covering(mminfo) && continue
any(mt′::MethodTable->mt′===mminfo.mt, mts) && continue
push!(mts, mminfo.mt)
n += nmatches(mminfo)
end
return mts
return n
end

function find_method_matches(interp::AbstractInterpreter, argtypes::Vector{Any}, @nospecialize(atype);
Expand Down Expand Up @@ -339,7 +320,7 @@ function find_union_split_method_matches(interp::AbstractInterpreter, argtypes::
end
valid_worlds = intersect(valid_worlds, thismatches.valid_worlds)
thisfullmatch = any(match::MethodMatch->match.fully_covers, thismatches)
thisinfo = MethodMatchInfo(thismatches, mt, thisfullmatch)
thisinfo = MethodMatchInfo(thismatches, mt, sig_n, thisfullmatch)
push!(infos, thisinfo)
end
info = UnionSplitInfo(infos)
Expand All @@ -360,7 +341,7 @@ function find_simple_method_matches(interp::AbstractInterpreter, @nospecialize(a
return FailedMethodMatch("Too many methods matched")
end
fullmatch = any(match::MethodMatch->match.fully_covers, matches)
info = MethodMatchInfo(matches, mt, fullmatch)
info = MethodMatchInfo(matches, mt, atype, fullmatch)
return MethodMatches(matches.matches, info, matches.valid_worlds)
end

Expand Down Expand Up @@ -560,31 +541,6 @@ function collect_slot_refinements(𝕃ᵢ::AbstractLattice, applicable::Vector{A
return slotrefinements
end

function add_call_backedges!(interp::AbstractInterpreter, @nospecialize(rettype),
all_effects::Effects, any_slot_refined::Bool, edges::Vector{MethodInstance},
matches::Union{MethodMatches,UnionSplitMethodMatches}, @nospecialize(atype),
sv::AbsIntState)
# don't bother to add backedges when both type and effects information are already
# maximized to the top since a new method couldn't refine or widen them anyway
if rettype === Any
# ignore the `:nonoverlayed` property if `interp` doesn't use overlayed method table
# since it will never be tainted anyway
if !isoverlayed(method_table(interp))
all_effects = Effects(all_effects; nonoverlayed=ALWAYS_FALSE)
end
if all_effects === Effects() && !any_slot_refined
return nothing
end
end
for edge in edges
add_backedge!(sv, edge)
end
# also need an edge to the method table in case something gets
# added that did not intersect with any existing method
add_uncovered_edges!(sv, matches, atype)
return nothing
end

const RECURSION_UNUSED_MSG = "Bounded recursion detected with unused result. Annotated return type may be wider than true result."
const RECURSION_MSG = "Bounded recursion detected. Call was widened to force convergence."
const RECURSION_MSG_HARDLIMIT = "Bounded recursion detected under hardlimit. Call was widened to force convergence."
Expand Down Expand Up @@ -857,13 +813,11 @@ struct ConstCallResults
exct::Any
const_result::ConstResult
effects::Effects
edge::MethodInstance
function ConstCallResults(
@nospecialize(rt), @nospecialize(exct),
const_result::ConstResult,
effects::Effects,
edge::MethodInstance)
return new(rt, exct, const_result, effects, edge)
effects::Effects)
return new(rt, exct, const_result, effects)
end
end

Expand Down Expand Up @@ -1015,9 +969,9 @@ function concrete_eval_call(interp::AbstractInterpreter,
catch e
# The evaluation threw. By :consistent-cy, we're guaranteed this would have happened at runtime.
# Howevever, at present, :consistency does not mandate the type of the exception
return ConstCallResults(Bottom, Any, ConcreteResult(edge, result.effects), result.effects, edge)
return ConstCallResults(Bottom, Any, ConcreteResult(edge, result.effects), result.effects)
end
return ConstCallResults(Const(value), Union{}, ConcreteResult(edge, EFFECTS_TOTAL, value), EFFECTS_TOTAL, edge)
return ConstCallResults(Const(value), Union{}, ConcreteResult(edge, EFFECTS_TOTAL, value), EFFECTS_TOTAL)
end

# check if there is a cycle and duplicated inference of `mi`
Expand Down Expand Up @@ -1282,7 +1236,7 @@ function semi_concrete_eval_call(interp::AbstractInterpreter,
effects = Effects(effects; noub=ALWAYS_TRUE)
end
exct = refine_exception_type(result.exct, effects)
return ConstCallResults(rt, exct, SemiConcreteResult(mi, ir, effects, spec_info(irsv)), effects, mi)
return ConstCallResults(rt, exct, SemiConcreteResult(mi, ir, effects, spec_info(irsv)), effects)
end
end
end
Expand All @@ -1291,7 +1245,7 @@ end

const_prop_result(inf_result::InferenceResult) =
ConstCallResults(inf_result.result, inf_result.exc_result, ConstPropResult(inf_result),
inf_result.ipo_effects, inf_result.linfo)
inf_result.ipo_effects)

# return cached result of constant analysis
return_localcache_result(::AbstractInterpreter, inf_result::InferenceResult, ::AbsIntState) =
Expand Down Expand Up @@ -2229,7 +2183,7 @@ function abstract_invoke(interp::AbstractInterpreter, arginfo::ArgInfo, si::Stmt
mresult = abstract_call_method(interp, method, ti, env, false, si, sv)::Future
match = MethodMatch(ti, env, method, argtype <: method.sig)
return Future{CallMeta}(mresult, interp, sv) do result, interp, sv
(; rt, exct, edge, effects, volatile_inf_result) = result
(; rt, exct, effects, volatile_inf_result) = result
res = nothing
sig = match.spec_types
argtypes′ = invoke_rewrite(argtypes)
Expand All @@ -2250,15 +2204,14 @@ function abstract_invoke(interp::AbstractInterpreter, arginfo::ArgInfo, si::Stmt
const_result = volatile_inf_result
if const_call_result !== nothing
if const_call_result.rt ⊑ rt
(; rt, effects, const_result, edge) = const_call_result
(; rt, effects, const_result) = const_call_result
end
if const_call_result.exct ⋤ exct
(; exct, const_result, edge) = const_call_result
(; exct, const_result) = const_call_result
end
end
rt = from_interprocedural!(interp, rt, sv, arginfo, sig)
info = InvokeCallInfo(match, const_result)
edge !== nothing && add_invoke_backedge!(sv, lookupsig, edge)
info = InvokeCallInfo(match, const_result, lookupsig)
if !match.fully_covers
effects = Effects(effects; nothrow=false)
exct = exct ⊔ TypeError
Expand Down Expand Up @@ -2454,19 +2407,19 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
mresult = abstract_call_method(interp, ocmethod, sig, Core.svec(), false, si, sv)
ocsig_box = Core.Box(ocsig)
return Future{CallMeta}(mresult, interp, sv) do result, interp, sv
(; rt, exct, edge, effects, volatile_inf_result, edgecycle) = result
(; rt, exct, effects, volatile_inf_result, edgecycle) = result
𝕃ₚ = ipo_lattice(interp)
⊑, ⋤, ⊔ = partialorder(𝕃ₚ), strictneqpartialorder(𝕃ₚ), join(𝕃ₚ)
const_result = volatile_inf_result
if !edgecycle
const_call_result = abstract_call_method_with_const_args(interp, result,
nothing, arginfo, si, match, sv)
#=f=#nothing, arginfo, si, match, sv)
if const_call_result !== nothing
if const_call_result.rt ⊑ rt
(; rt, effects, const_result, edge) = const_call_result
(; rt, effects, const_result) = const_call_result
end
if const_call_result.exct ⋤ exct
(; exct, const_result, edge) = const_call_result
(; exct, const_result) = const_call_result
end
end
end
Expand All @@ -2481,7 +2434,6 @@ function abstract_call_opaque_closure(interp::AbstractInterpreter,
end
rt = from_interprocedural!(interp, rt, sv, arginfo, match.spec_types)
info = OpaqueClosureCallInfo(match, const_result)
edge !== nothing && add_backedge!(sv, edge)
return CallMeta(rt, exct, effects, info)
end
end
Expand Down Expand Up @@ -3430,7 +3382,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState, nextr
while currpc < bbend
currpc += 1
frame.currpc = currpc
empty_backedges!(frame, currpc)
stmt = frame.src.code[currpc]
# If we're at the end of the basic block ...
if currpc == bbend
Expand Down
62 changes: 7 additions & 55 deletions base/compiler/inferencestate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ mutable struct InferenceState
# TODO: Could keep this sparsely by doing structural liveness analysis ahead of time.
bb_vartables::Vector{Union{Nothing,VarTable}} # nothing if not analyzed yet
ssavaluetypes::Vector{Any}
stmt_edges::Vector{Vector{Any}}
edges::Vector{Any}
stmt_info::Vector{CallInfo}

#= intermediate states for interprocedural abstract interpretation =#
Expand Down Expand Up @@ -302,7 +302,7 @@ mutable struct InferenceState
nssavalues = src.ssavaluetypes::Int
ssavalue_uses = find_ssavalue_uses(code, nssavalues)
nstmts = length(code)
stmt_edges = Vector{Vector{Any}}(undef, nstmts)
edges = []
stmt_info = CallInfo[ NoCallInfo() for i = 1:nstmts ]

nslots = length(src.slotflags)
Expand All @@ -327,7 +327,7 @@ mutable struct InferenceState
unreachable = BitSet()
pclimitations = IdSet{InferenceState}()
limitations = IdSet{InferenceState}()
cycle_backedges = Vector{Tuple{InferenceState,Int}}()
cycle_backedges = Tuple{InferenceState,Int}[]
callstack = AbsIntState[]
tasks = WorkThunk[]

Expand All @@ -350,10 +350,12 @@ mutable struct InferenceState

restrict_abstract_call_sites = isa(def, Module)

parentid = frameid = cycleid = 0

this = new(
mi, world, mod, sptypes, slottypes, src, cfg, spec_info,
currbb, currpc, ip, handler_info, ssavalue_uses, bb_vartables, ssavaluetypes, stmt_edges, stmt_info,
tasks, pclimitations, limitations, cycle_backedges, callstack, 0, 0, 0,
currbb, currpc, ip, handler_info, ssavalue_uses, bb_vartables, ssavaluetypes, edges, stmt_info,
tasks, pclimitations, limitations, cycle_backedges, callstack, parentid, frameid, cycleid,
result, unreachable, valid_worlds, bestguess, exc_bestguess, ipo_effects,
restrict_abstract_call_sites, cache_mode, insert_coverage,
interp)
Expand Down Expand Up @@ -754,30 +756,6 @@ function record_ssa_assign!(𝕃ᵢ::AbstractLattice, ssa_id::Int, @nospecialize
return nothing
end

function add_cycle_backedge!(caller::InferenceState, frame::InferenceState)
update_valid_age!(caller, frame.valid_worlds)
backedge = (caller, caller.currpc)
contains_is(frame.cycle_backedges, backedge) || push!(frame.cycle_backedges, backedge)
add_backedge!(caller, frame.linfo)
return frame
end

function get_stmt_edges!(caller::InferenceState, currpc::Int=caller.currpc)
stmt_edges = caller.stmt_edges
if !isassigned(stmt_edges, currpc)
return stmt_edges[currpc] = Any[]
else
return stmt_edges[currpc]
end
end

function empty_backedges!(frame::InferenceState, currpc::Int=frame.currpc)
if isassigned(frame.stmt_edges, currpc)
empty!(frame.stmt_edges[currpc])
end
return nothing
end

function narguments(sv::InferenceState, include_va::Bool=true)
nargs = Int(sv.src.nargs)
if !include_va
Expand Down Expand Up @@ -1008,32 +986,6 @@ function callers_in_cycle(sv::InferenceState)
end
callers_in_cycle(sv::IRInterpretationState) = AbsIntCycle(sv.callstack::Vector{AbsIntState}, 0, 0)

# temporarily accumulate our edges to later add as backedges in the callee
function add_backedge!(caller::InferenceState, mi::MethodInstance)
isa(caller.linfo.def, Method) || return nothing # don't add backedges to toplevel method instance
return push!(get_stmt_edges!(caller), mi)
end
function add_backedge!(irsv::IRInterpretationState, mi::MethodInstance)
return push!(irsv.edges, mi)
end

function add_invoke_backedge!(caller::InferenceState, @nospecialize(invokesig::Type), mi::MethodInstance)
isa(caller.linfo.def, Method) || return nothing # don't add backedges to toplevel method instance
return push!(get_stmt_edges!(caller), invokesig, mi)
end
function add_invoke_backedge!(irsv::IRInterpretationState, @nospecialize(invokesig::Type), mi::MethodInstance)
return push!(irsv.edges, invokesig, mi)
end

# used to temporarily accumulate our no method errors to later add as backedges in the callee method table
function add_mt_backedge!(caller::InferenceState, mt::MethodTable, @nospecialize(typ))
isa(caller.linfo.def, Method) || return nothing # don't add backedges to toplevel method instance
return push!(get_stmt_edges!(caller), mt, typ)
end
function add_mt_backedge!(irsv::IRInterpretationState, mt::MethodTable, @nospecialize(typ))
return push!(irsv.edges, mt, typ)
end

get_curr_ssaflag(sv::InferenceState) = sv.src.ssaflags[sv.currpc]
get_curr_ssaflag(sv::IRInterpretationState) = sv.ir.stmts[sv.curridx][:flag]

Expand Down
Loading