Skip to content

Commit

Permalink
Extend CallInfo interface with add_unhandled_case_edges!
Browse files Browse the repository at this point in the history
We need to add MT edges for any unhandled edges in the union-split of
a `CallInfo`
  • Loading branch information
topolarity committed Sep 9, 2024
1 parent b5e10c8 commit 75f9d08
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 11 deletions.
6 changes: 3 additions & 3 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -331,15 +331,15 @@ function find_union_split_method_matches(interp::AbstractInterpreter, argtypes::
end
valid_worlds = intersect(valid_worlds, matches.valid_worlds)
thisfullmatch = any(match::MethodMatch->match.fully_covers, matches)
found = false
mt_found = false
for (i, mt′) in enumerate(mts)
if mt′ === mt
fullmatches[i] &= thisfullmatch
found = true
mt_found = true
break
end
end
if !found
if !mt_found
push!(mts, mt)
push!(fullmatches, thisfullmatch)
end
Expand Down
9 changes: 1 addition & 8 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1399,14 +1399,7 @@ function compute_inlining_cases(@nospecialize(info::CallInfo), flag::UInt32, sig
if !fully_covered
atype = argtypes_to_type(sig.argtypes)
# We will emit an inline MethodError so we need a backedge to the MethodTable
unwrapped_info = info isa ConstCallInfo ? info.call : info
if unwrapped_info isa UnionSplitInfo
for (fullmatch, mt) in zip(unwrapped_info.fullmatches, unwrapped_info.mts)
!fullmatch && push!(state.edges, mt, atype)
end
elseif unwrapped_info isa MethodMatchInfo
push!(state.edges, unwrapped_info.mt, atype)
else @assert false end
add_unhandled_case_edges!(state.edges, info, atype)
end
elseif !isempty(cases)
# if we've not seen all candidates, union split is valid only for dispatch tuples
Expand Down
7 changes: 7 additions & 0 deletions base/compiler/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ end
nsplit_impl(info::MethodMatchInfo) = 1
getsplit_impl(info::MethodMatchInfo, idx::Int) = (@assert idx == 1; info.results)
getresult_impl(::MethodMatchInfo, ::Int) = nothing
add_unhandled_case_edges_impl(edges::Vector{Any}, info::MethodMatchInfo, @nospecialize(atype)) = (!info.fullmatch && push!(edges, info.mt, atype); )

"""
info::UnionSplitInfo <: CallInfo
Expand Down Expand Up @@ -66,6 +67,11 @@ end
nsplit_impl(info::UnionSplitInfo) = length(info.matches)
getsplit_impl(info::UnionSplitInfo, idx::Int) = info.matches[idx]
getresult_impl(::UnionSplitInfo, ::Int) = nothing
function add_unhandled_case_edges_impl(edges::Vector{Any}, info::UnionSplitInfo, @nospecialize(atype))
for (mt, fullmatch) in zip(info.mts, info.fullmatches)
!fullmatch && push!(edges, mt, atype)
end
end

abstract type ConstResult end

Expand Down Expand Up @@ -109,6 +115,7 @@ end
nsplit_impl(info::ConstCallInfo) = nsplit(info.call)
getsplit_impl(info::ConstCallInfo, idx::Int) = getsplit(info.call, idx)
getresult_impl(info::ConstCallInfo, idx::Int) = info.results[idx]
add_unhandled_case_edges_impl(edges::Vector{Any}, info::ConstCallInfo, @nospecialize(atype)) = add_unhandled_case_edges!(edges, info.call, atype)

"""
info::MethodResultPure <: CallInfo
Expand Down
3 changes: 3 additions & 0 deletions base/compiler/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -451,9 +451,12 @@ abstract type CallInfo end
nsplit(info::CallInfo) = nsplit_impl(info)::Union{Nothing,Int}
getsplit(info::CallInfo, idx::Int) = getsplit_impl(info, idx)::MethodLookupResult
getresult(info::CallInfo, idx::Int) = getresult_impl(info, idx)
add_unhandled_case_edges!(edges::Vector{Any}, info::CallInfo, @nospecialize(atype)) = add_unhandled_case_edges_impl(edges, info, atype)


nsplit_impl(::CallInfo) = nothing
getsplit_impl(::CallInfo, ::Int) = error("unexpected call into `getsplit`")
getresult_impl(::CallInfo, ::Int) = nothing
add_unhandled_case_edges_impl(edges::Vector{Any}, info::CallInfo, @nospecialize(atype)) = nothing

@specialize
9 changes: 9 additions & 0 deletions test/compiler/AbstractInterpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -409,6 +409,7 @@ end
CC.nsplit_impl(info::NoinlineCallInfo) = CC.nsplit(info.info)
CC.getsplit_impl(info::NoinlineCallInfo, idx::Int) = CC.getsplit(info.info, idx)
CC.getresult_impl(info::NoinlineCallInfo, idx::Int) = CC.getresult(info.info, idx)
CC.add_unhandled_case_edges_impl(edges::Vector{Any}, info::NoinlineCallInfo, @nospecialize(atype)) = CC.add_unhandled_case_edges!(edges, info.info, atype)

function CC.abstract_call(interp::NoinlineInterpreter,
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.InferenceState, max_methods::Int)
Expand All @@ -431,6 +432,8 @@ end
@inline function inlined_usually(x, y, z)
return x * y + z
end
foo_split(x::Float64) = 1
foo_split(x::Int) = 2

# check if the inlining algorithm works as expected
let src = code_typed1((Float64,Float64,Float64)) do x, y, z
Expand All @@ -444,6 +447,7 @@ let NoinlineModule = Module()
main_func(x, y, z) = inlined_usually(x, y, z)
@eval NoinlineModule noinline_func(x, y, z) = $inlined_usually(x, y, z)
@eval OtherModule other_func(x, y, z) = $inlined_usually(x, y, z)
@eval NoinlineModule bar_split_error() = $foo_split(Core.compilerbarrier(:type, nothing))

interp = NoinlineInterpreter(Set((NoinlineModule,)))

Expand Down Expand Up @@ -473,6 +477,11 @@ let NoinlineModule = Module()
@test count(isinvoke(:inlined_usually), src.code) == 0
@test count(iscall((src, inlined_usually)), src.code) == 0
end

let src = code_typed1(NoinlineModule.bar_split_error)
@test count(iscall((src, foo_split)), src.code) == 0
@test count(iscall((src, Core.throw_methoderror)), src.code) > 0
end
end

# Make sure that Core.Compiler has enough NamedTuple infrastructure
Expand Down

0 comments on commit 75f9d08

Please sign in to comment.