diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index dc4663a877d9a..23b00134c6071 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -1135,7 +1135,8 @@ function abstract_call_unionall(argtypes::Vector{Any}) end function abstract_invoke(interp::AbstractInterpreter, argtypes::Vector{Any}, sv::InferenceState) - ft = widenconst(argtype_by_index(argtypes, 2)) + ft′ = argtype_by_index(argtypes, 2) + ft = widenconst(ft′) ft === Bottom && return CallMeta(Bottom, false) (types, isexact, isconcrete, istype) = instanceof_tfunc(argtype_by_index(argtypes, 3)) types === Bottom && return CallMeta(Bottom, false) @@ -1149,15 +1150,30 @@ function abstract_invoke(interp::AbstractInterpreter, argtypes::Vector{Any}, sv: nargtype = Tuple{ft, nargtype.parameters...} argtype = Tuple{ft, argtype.parameters...} result = findsup(types, method_table(interp)) - if result === nothing - return CallMeta(Any, false) - end + result === nothing && return CallMeta(Any, false) method, valid_worlds = result update_valid_age!(sv, valid_worlds) (ti, env::SimpleVector) = ccall(:jl_type_intersection_with_env, Any, (Any, Any), nargtype, method.sig)::SimpleVector - rt, edge = typeinf_edge(interp, method, ti, env, sv) + (; rt, edge) = result = abstract_call_method(interp, method, ti, env, false, sv) edge !== nothing && add_backedge!(edge::MethodInstance, sv) - return CallMeta(rt, InvokeCallInfo(MethodMatch(ti, env, method, argtype <: method.sig))) + match = MethodMatch(ti, env, method, argtype <: method.sig) + # try constant propagation with manual inlinings of some of the heuristics + # since some checks within `abstract_call_method_with_const_args` seem a bit costly + const_prop_entry_heuristic(interp, result, sv) || return CallMeta(rt, InvokeCallInfo(match, nothing)) + argtypes′ = argtypes[4:end] + const_prop_argument_heuristic(interp, argtypes′) || const_prop_rettype_heuristic(interp, rt) || return CallMeta(rt, InvokeCallInfo(match, nothing)) + pushfirst!(argtypes′, ft) + # # typeintersect might have narrowed signature, but the accuracy gain doesn't seem worth the cost involved with the lattice comparisons + # for i in 1:length(argtypes′) + # t, a = ti.parameters[i], argtypes′[i] + # argtypes′[i] = t ⊑ a ? t : a + # end + const_rt, const_result = abstract_call_method_with_const_args(interp, result, argtype_to_function(ft′), argtypes′, match, sv, false) + if const_rt !== rt && const_rt ⊑ rt + return CallMeta(const_rt, InvokeCallInfo(match, const_result)) + else + return CallMeta(rt, InvokeCallInfo(match, nothing)) + end end # call where the function is known exactly @@ -1291,17 +1307,12 @@ function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{ sv::InferenceState, max_methods::Int = InferenceParams(interp).MAX_METHODS) #print("call ", e.args[1], argtypes, "\n\n") ft = argtypes[1] - if isa(ft, Const) - f = ft.val - elseif isconstType(ft) - f = ft.parameters[1] - elseif isa(ft, DataType) && isdefined(ft, :instance) - f = ft.instance - elseif isa(ft, PartialOpaque) + f = argtype_to_function(ft) + if isa(ft, PartialOpaque) return abstract_call_opaque_closure(interp, ft, argtypes[2:end], sv) elseif isa(unwrap_unionall(ft), DataType) && unwrap_unionall(ft).name === typename(Core.OpaqueClosure) return CallMeta(rewrap_unionall(unwrap_unionall(ft).parameters[2], ft), false) - else + elseif f === nothing # non-constant function, but the number of arguments is known # and the ft is not a Builtin or IntrinsicFunction if typeintersect(widenconst(ft), Union{Builtin, Core.OpaqueClosure}) != Union{} @@ -1313,6 +1324,18 @@ function abstract_call(interp::AbstractInterpreter, fargs::Union{Nothing,Vector{ return abstract_call_known(interp, f, fargs, argtypes, sv, max_methods) end +function argtype_to_function(@nospecialize(ft)) + if isa(ft, Const) + return ft.val + elseif isconstType(ft) + return ft.parameters[1] + elseif isa(ft, DataType) && isdefined(ft, :instance) + return ft.instance + else + return nothing + end +end + function sp_type_rewrap(@nospecialize(T), linfo::MethodInstance, isreturn::Bool) isref = false if T === Bottom diff --git a/base/compiler/ssair/inlining.jl b/base/compiler/ssair/inlining.jl index b9eb382ffaa7e..78edef88439e9 100644 --- a/base/compiler/ssair/inlining.jl +++ b/base/compiler/ssair/inlining.jl @@ -1049,12 +1049,12 @@ is_builtin(s::Signature) = isa(s.f, Builtin) || s.ft ⊑ Builtin -function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, info::InvokeCallInfo, +function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, (; match, result)::InvokeCallInfo, state::InliningState, todo::Vector{Pair{Int, Any}}) stmt = ir.stmts[idx][:inst] calltype = ir.stmts[idx][:type] - if !info.match.fully_covers + if !match.fully_covers # TODO: We could union split out the signature check and continue on return nothing end @@ -1064,7 +1064,17 @@ function inline_invoke!(ir::IRCode, idx::Int, sig::Signature, info::InvokeCallIn atypes = atypes[4:end] pushfirst!(atypes, atype0) - result = analyze_method!(info.match, atypes, state, calltype) + if isa(result, InferenceResult) + item = InliningTodo(result, atypes, calltype) + validate_sparams(item.mi.sparam_vals) || return nothing + if argtypes_to_type(atypes) <: item.mi.def.sig + state.mi_cache !== nothing && (item = resolve_todo(item, state)) + handle_single_case!(ir, stmt, idx, item, true, todo) + return nothing + end + end + + result = analyze_method!(match, atypes, state, calltype) handle_single_case!(ir, stmt, idx, result, true, todo) return nothing end diff --git a/base/compiler/stmtinfo.jl b/base/compiler/stmtinfo.jl index ad7e8886b6cce..a6ffee299c4f5 100644 --- a/base/compiler/stmtinfo.jl +++ b/base/compiler/stmtinfo.jl @@ -108,6 +108,7 @@ method being processed. """ struct InvokeCallInfo match::MethodMatch + result::Union{Nothing,InferenceResult} end struct OpaqueClosureCallInfo diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index d5bad34e38322..008e6ff0d6997 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -3355,3 +3355,43 @@ let Expr(:opaque_closure_method, nothing, 2, LineNumberNode(0, nothing), ci)))(true, 1.0) @test Base.return_types(oc, Tuple{}) == Any[Float64] end + +@testset "constant prop' on `invoke` calls" begin + m = Module() + + # simple cases + @eval m begin + f(a::Any, sym::Bool) = sym ? Any : :any + f(a::Number, sym::Bool) = sym ? Number : :number + end + @test (@eval m Base.return_types((Any,)) do a + Base.@invoke f(a::Any, true::Bool) + end) == Any[Type{Any}] + @test (@eval m Base.return_types((Any,)) do a + Base.@invoke f(a::Number, true::Bool) + end) == Any[Type{Number}] + @test (@eval m Base.return_types((Any,)) do a + Base.@invoke f(a::Any, false::Bool) + end) == Any[Symbol] + @test (@eval m Base.return_types((Any,)) do a + Base.@invoke f(a::Number, false::Bool) + end) == Any[Symbol] + + # https://github.com/JuliaLang/julia/issues/41024 + @eval m begin + # mixin, which expects common field `x::Int` + abstract type AbstractInterface end + Base.getproperty(x::AbstractInterface, sym::Symbol) = + sym === :x ? getfield(x, sym)::Int : + return getfield(x, sym) # fallback + + # extended mixin, which expects additional field `y::Rational{Int}` + abstract type AbstractInterfaceExtended <: AbstractInterface end + Base.getproperty(x::AbstractInterfaceExtended, sym::Symbol) = + sym === :y ? getfield(x, sym)::Rational{Int} : + return Base.@invoke getproperty(x::AbstractInterface, sym::Symbol) + end + @test (@eval m Base.return_types((AbstractInterfaceExtended,)) do x + x.x + end) == Any[Int] +end