Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inference: enable constant propagation for invoked calls, fixes #41024 #41383

Merged
merged 4 commits into from
Jun 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 37 additions & 14 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1135,7 +1135,8 @@ function abstract_call_unionall(argtypes::Vector{Any})
end

function abstract_invoke(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::InferenceState)
ft = widenconst(argtype_by_index(argtypes, 2))
ft′ = argtype_by_index(argtypes, 2)
ft = widenconst(ft′)
ft === Bottom && return CallMeta(Bottom, false)
(types, isexact, isconcrete, istype) = instanceof_tfunc(argtype_by_index(argtypes, 3))
types === Bottom && return CallMeta(Bottom, false)
Expand All @@ -1149,15 +1150,30 @@ function abstract_invoke(interp::AbstractInterpreter, argtypes::Vector{Any}, sv:
nargtype = Tuple{ft, nargtype.parameters...}
argtype = Tuple{ft, argtype.parameters...}
result = findsup(types, method_table(interp))
if result === nothing
return CallMeta(Any, false)
end
result === nothing && return CallMeta(Any, false)
method, valid_worlds = result
update_valid_age!(sv, valid_worlds)
(ti, env::SimpleVector) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), nargtype, method.sig)::SimpleVector
rt, edge = typeinf_edge(interp, method, ti, env, sv)
(; rt, edge) = result = abstract_call_method(interp, method, ti, env, false, sv)
edge !== nothing && add_backedge!(edge::MethodInstance, sv)
return CallMeta(rt, InvokeCallInfo(MethodMatch(ti, env, method, argtype <: method.sig)))
match = MethodMatch(ti, env, method, argtype <: method.sig)
# try constant propagation with manual inlinings of some of the heuristics
# since some checks within `abstract_call_method_with_const_args` seem a bit costly
const_prop_entry_heuristic(interp, result, sv) || return CallMeta(rt, InvokeCallInfo(match, nothing))
argtypes′ = argtypes[4:end]
const_prop_argument_heuristic(interp, argtypes′) || const_prop_rettype_heuristic(interp, rt) || return CallMeta(rt, InvokeCallInfo(match, nothing))
pushfirst!(argtypes′, ft)
# # typeintersect might have narrowed signature, but the accuracy gain doesn't seem worth the cost involved with the lattice comparisons
# for i in 1:length(argtypes′)
# t, a = ti.parameters[i], argtypes′[i]
# argtypes′[i] = t ⊑ a ? t : a
# end
const_rt, const_result = abstract_call_method_with_const_args(interp, result, argtype_to_function(ft′), argtypes′, match, sv, false)
if const_rt !== rt && const_rt ⊑ rt
return CallMeta(const_rt, InvokeCallInfo(match, const_result))
else
return CallMeta(rt, InvokeCallInfo(match, nothing))
end
end

# call where the function is known exactly
Expand Down Expand Up @@ -1291,17 +1307,12 @@ function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{
sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS)
#print("call ", e.args[1], argtypes, "\n\n")
ft = argtypes[1]
if isa(ft, Const)
f = ft.val
elseif isconstType(ft)
f = ft.parameters[1]
elseif isa(ft, DataType) && isdefined(ft, :instance)
f = ft.instance
elseif isa(ft, PartialOpaque)
f = argtype_to_function(ft)
if isa(ft, PartialOpaque)
return abstract_call_opaque_closure(interp, ft, argtypes[2:end], sv)
elseif isa(unwrap_unionall(ft), DataType) && unwrap_unionall(ft).name === typename(Core.OpaqueClosure)
return CallMeta(rewrap_unionall(unwrap_unionall(ft).parameters[2], ft), false)
else
elseif f === nothing
# non-constant function, but the number of arguments is known
# and the ft is not a Builtin or IntrinsicFunction
if typeintersect(widenconst(ft), Union{Builtin, Core.OpaqueClosure}) != Union{}
Expand All @@ -1313,6 +1324,18 @@ function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{
return abstract_call_known(interp, f, fargs, argtypes, sv, max_methods)
end

function argtype_to_function(@nospecialize(ft))
if isa(ft, Const)
return ft.val
elseif isconstType(ft)
return ft.parameters[1]
elseif isa(ft, DataType) && isdefined(ft, :instance)
return ft.instance
else
return nothing
end
end

function sp_type_rewrap(@nospecialize(T), linfo::MethodInstance, isreturn::Bool)
isref = false
if T === Bottom
Expand Down
16 changes: 13 additions & 3 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1049,12 +1049,12 @@ is_builtin(s::Signature) =
isa(s.f, Builtin) ||
s.ft ⊑ Builtin

function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, info::InvokeCallInfo,
function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, (; match, result)::InvokeCallInfo,
state::InliningState, todo::Vector{Pair{Int, Any}})
stmt = ir.stmts[idx][:inst]
calltype = ir.stmts[idx][:type]

if !info.match.fully_covers
if !match.fully_covers
# TODO: We could union split out the signature check and continue on
return nothing
end
Expand All @@ -1064,7 +1064,17 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, info::InvokeCallIn
atypes = atypes[4:end]
pushfirst!(atypes, atype0)

result = analyze_method!(info.match, atypes, state, calltype)
if isa(result, InferenceResult)
item = InliningTodo(result, atypes, calltype)
validate_sparams(item.mi.sparam_vals) || return nothing
if argtypes_to_type(atypes) <: item.mi.def.sig
state.mi_cache !== nothing && (item = resolve_todo(item, state))
handle_single_case!(ir, stmt, idx, item, true, todo)
return nothing
end
end

result = analyze_method!(match, atypes, state, calltype)
handle_single_case!(ir, stmt, idx, result, true, todo)
return nothing
end
Expand Down
1 change: 1 addition & 0 deletions base/compiler/stmtinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ method being processed.
"""
struct InvokeCallInfo
match::MethodMatch
result::Union{Nothing,InferenceResult}
end

struct OpaqueClosureCallInfo
Expand Down
40 changes: 40 additions & 0 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3355,3 +3355,43 @@ let
Expr(:opaque_closure_method, nothing, 2, LineNumberNode(0, nothing), ci)))(true, 1.0)
@test Base.return_types(oc, Tuple{}) == Any[Float64]
end

@testset "constant prop' on `invoke` calls" begin
m = Module()

# simple cases
@eval m begin
f(a::Any, sym::Bool) = sym ? Any : :any
f(a::Number, sym::Bool) = sym ? Number : :number
end
@test (@eval m Base.return_types((Any,)) do a
Base.@invoke f(a::Any, true::Bool)
end) == Any[Type{Any}]
@test (@eval m Base.return_types((Any,)) do a
Base.@invoke f(a::Number, true::Bool)
end) == Any[Type{Number}]
@test (@eval m Base.return_types((Any,)) do a
Base.@invoke f(a::Any, false::Bool)
end) == Any[Symbol]
@test (@eval m Base.return_types((Any,)) do a
Base.@invoke f(a::Number, false::Bool)
end) == Any[Symbol]

# https://github.com/JuliaLang/julia/issues/41024
@eval m begin
# mixin, which expects common field `x::Int`
abstract type AbstractInterface end
Base.getproperty(x::AbstractInterface, sym::Symbol) =
sym === :x ? getfield(x, sym)::Int :
return getfield(x, sym) # fallback

# extended mixin, which expects additional field `y::Rational{Int}`
abstract type AbstractInterfaceExtended <: AbstractInterface end
Base.getproperty(x::AbstractInterfaceExtended, sym::Symbol) =
sym === :y ? getfield(x, sym)::Rational{Int} :
return Base.@invoke getproperty(x::AbstractInterface, sym::Symbol)
end
@test (@eval m Base.return_types((AbstractInterfaceExtended,)) do x
x.x
end) == Any[Int]
end