diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 68299eb5385cb..e1549b1dd3746 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -116,8 +116,7 @@ function abstract_call_method_with_const_args(@nospecialize(f), argtypes::Vector method.isva && (nargs -= 1) length(argtypes) >= nargs || return Any # probably limit_tuple_type made this non-matching method apparently match haveconst = false - for i in 1:nargs - a = argtypes[i] + for a in argtypes if isa(a, Const) && !isdefined(typeof(a.val), :instance) && !(isa(a.val, Type) && issingletontype(a.val)) # have new information from argtypes that wasn't available from the signature haveconst = true @@ -144,8 +143,7 @@ function abstract_call_method_with_const_args(@nospecialize(f), argtypes::Vector tm = _topmod(sv) if !istopfunction(tm, f, :getproperty) && !istopfunction(tm, f, :setproperty!) # in this case, see if all of the arguments are constants - for i in 1:nargs - a = argtypes[i] + for a in argtypes if !isa(a, Const) && !isconstType(a) return Any end @@ -156,6 +154,17 @@ function abstract_call_method_with_const_args(@nospecialize(f), argtypes::Vector if inf_result === nothing inf_result = InferenceResult(code) atypes = get_argtypes(inf_result) + if method.isva + vargs = argtypes[(nargs + 1):end] + for i in 1:length(vargs) + a = vargs[i] + if i > length(inf_result.vargs) + push!(inf_result.vargs, a) + elseif a isa Const + inf_result.vargs[i] = a + end + end + end for i in 1:nargs a = argtypes[i] if a isa Const @@ -187,7 +196,13 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl cyclei = 0 infstate = sv edgecycle = false + # The `method_for_inference_heuristics` will expand the given method's generator if + # necessary in order to retrieve this field from the generated `CodeInfo`, if it exists. + # The other `CodeInfo`s we inspect will already have this field inflated, so we just + # access it directly instead (to avoid regeneration). 
method2 = method_for_inference_heuristics(method, sig, sparams, sv.params.world) # Union{Method, Nothing} + sv_method2 = sv.src.method_for_inference_limit_heuristics # limit only if user token match + sv_method2 isa Method || (sv_method2 = nothing) # Union{Method, Nothing} while !(infstate === nothing) infstate = infstate::InferenceState if method === infstate.linfo.def @@ -208,7 +223,9 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl for parent in infstate.callers_in_cycle # check in the cycle list first # all items in here are mutual parents of all others - if parent.linfo.def === sv.linfo.def + parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match + parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing} + if parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2 topmost = infstate edgecycle = true break @@ -218,7 +235,9 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl # then check the parent link if topmost === nothing && parent !== nothing parent = parent::InferenceState - if parent.cached && parent.linfo.def === sv.linfo.def + parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match + parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing} + if parent.cached && parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2 topmost = infstate edgecycle = true end @@ -315,8 +334,8 @@ function precise_container_type(@nospecialize(arg), @nospecialize(typ), vtypes:: end arg = ssa_def_expr(arg, sv) - if is_specializable_vararg_slot(arg, sv) - return Any[rewrap_unionall(p, sv.linfo.specTypes) for p in sv.vararg_type_container.parameters] + if is_specializable_vararg_slot(arg, sv.nargs, sv.result.vargs) + return sv.result.vargs end tti0 = widenconst(typ) @@ -482,7 +501,13 @@ function abstract_call(@nospecialize(f), fargs::Union{Tuple{},Vector{Any}}, argt tm = _topmod(sv) if isa(f, Builtin) || isa(f, IntrinsicFunction) rt = builtin_tfunction(f, argtypes[2:end], sv) - if (rt === Bool || (isa(rt, Const) && isa(rt.val, Bool))) && isa(fargs, Vector{Any}) + if f === getfield && isa(fargs, Vector{Any}) && length(argtypes) == 3 && isa(argtypes[3], Const) && isa(argtypes[3].val, Int) && argtypes[2] ⊑ Tuple + cti = precise_container_type(fargs[2], argtypes[2], vtypes, sv) + idx = argtypes[3].val + if 1 <= idx <= length(cti) + rt = unwrapva(cti[idx]) + end + elseif (rt === Bool || (isa(rt, Const) && isa(rt.val, Bool))) && isa(fargs, Vector{Any}) # perform very limited back-propagation of type information for `is` and `isa` if f === isa a = ssa_def_expr(fargs[2], sv) diff --git a/base/compiler/inferenceresult.jl b/base/compiler/inferenceresult.jl index e533f0f13e376..6b9b2a2ba20d8 100644 --- a/base/compiler/inferenceresult.jl +++ b/base/compiler/inferenceresult.jl @@ -5,6 +5,9 @@ const EMPTY_VECTOR = Vector{Any}() mutable struct InferenceResult linfo::MethodInstance args::Vector{Any} + vargs::Vector{Any} # Memoize vararg type info w/Consts here when calling get_argtypes + # on the InferenceResult, so that the optimizer can use this info + # later during inlining. 
result # ::Type, or InferenceState if WIP src::Union{CodeInfo, Nothing} # if inferred copy is available function InferenceResult(linfo::MethodInstance) @@ -13,7 +16,7 @@ mutable struct InferenceResult else result = linfo.rettype end - return new(linfo, EMPTY_VECTOR, result, nothing) + return new(linfo, EMPTY_VECTOR, Any[], result, nothing) end end @@ -31,7 +34,35 @@ function get_argtypes(result::InferenceResult) end vararg_type = Tuple else - vararg_type = rewrap(tupleparam_tail(atypes, nargs), linfo.specTypes) + laty = length(atypes) + if nargs > laty + va = atypes[laty] + if isvarargtype(va) + new_va = rewrap_unionall(unconstrain_vararg_length(va), linfo.specTypes) + vararg_type_vec = Any[new_va] + vararg_type = Tuple{new_va} + else + vararg_type_vec = Any[] + vararg_type = Tuple{} + end + else + vararg_type_vec = Any[] + for p in atypes[nargs:laty] + p = isvarargtype(p) ? unconstrain_vararg_length(p) : p + push!(vararg_type_vec, rewrap_unionall(p, linfo.specTypes)) + end + vararg_type = tuple_tfunc(Tuple{vararg_type_vec...}) + for i in 1:length(vararg_type_vec) + atyp = vararg_type_vec[i] + if isa(atyp, DataType) && isdefined(atyp, :instance) + # replace singleton types with their equivalent Const object + vararg_type_vec[i] = Const(atyp.instance) + elseif isconstType(atyp) + vararg_type_vec[i] = Const(atyp.parameters[1]) + end + end + end + result.vargs = vararg_type_vec end args[nargs] = vararg_type nargs -= 1 @@ -80,19 +111,12 @@ function cache_lookup(code::MethodInstance, argtypes::Vector{Any}, cache::Vector for cache_code in cache # try to search cache first cache_args = cache_code.args - if cache_code.linfo === code && length(cache_args) >= nargs + cache_vargs = cache_code.vargs + if cache_code.linfo === code && length(argtypes) === (length(cache_vargs) + nargs) cache_match = true - # verify that the trailing args (va) aren't Const - for i in (nargs + 1):length(cache_args) - if isa(cache_args[i], Const) - cache_match = false - break - end - end - cache_match || continue - for i in 1:nargs + for i in 1:length(argtypes) a = argtypes[i] - ca = cache_args[i] + ca = i <= nargs ? cache_args[i] : cache_vargs[i - nargs] # verify that all Const argument types match between the call and cache if (isa(a, Const) || isa(ca, Const)) && !(a === ca) cache_match = false diff --git a/base/compiler/inferencestate.jl b/base/compiler/inferencestate.jl index ba87c4b36c053..583ca82727b74 100644 --- a/base/compiler/inferencestate.jl +++ b/base/compiler/inferencestate.jl @@ -30,7 +30,6 @@ mutable struct InferenceState # ssavalue sparsity and restart info ssavalue_uses::Vector{BitSet} ssavalue_defs::Vector{LineNum} - vararg_type_container #::Type backedges::Vector{Tuple{InferenceState, LineNum}} # call-graph backedges connecting from callee to caller callers_in_cycle::Vector{InferenceState} @@ -102,19 +101,11 @@ mutable struct InferenceState # initial types nslots = length(src.slotnames) argtypes = get_argtypes(result) - vararg_type_container = nothing nargs = length(argtypes) s_argtypes = VarTable(undef, nslots) src.slottypes = Vector{Any}(undef, nslots) for i in 1:nslots at = (i > nargs) ? 
Bottom : argtypes[i] - if !toplevel && linfo.def.isva && i == nargs - if !(at == Tuple) # would just be a no-op - vararg_type_container = unwrap_unionall(at) - vararg_type = tuple_tfunc(vararg_type_container) # returns a Const object, if applicable - at = rewrap(vararg_type, linfo.specTypes) - end - end s_argtypes[i] = VarState(at, i > nargs) src.slottypes[i] = at end @@ -152,7 +143,7 @@ mutable struct InferenceState nargs, s_types, s_edges, Union{}, W, 1, n, cur_hand, handler_at, n_handlers, - ssavalue_uses, ssavalue_defs, vararg_type_container, + ssavalue_uses, ssavalue_defs, Vector{Tuple{InferenceState,LineNum}}(), # backedges Vector{InferenceState}(), # callers_in_cycle #=parent=#nothing, @@ -238,9 +229,8 @@ function add_mt_backedge!(mt::Core.MethodTable, @nospecialize(typ), caller::Infe nothing end -function is_specializable_vararg_slot(@nospecialize(arg), sv::InferenceState) - return (isa(arg, Slot) && slot_id(arg) == sv.nargs && - isa(sv.vararg_type_container, DataType)) +function is_specializable_vararg_slot(@nospecialize(arg), nargs, vargs) + return (isa(arg, Slot) && slot_id(arg) == nargs && !isempty(vargs)) end function print_callstack(sv::InferenceState) diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index eb213270fe0ad..5943037cf33c5 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -6,7 +6,7 @@ mutable struct OptimizationState linfo::MethodInstance - vararg_type_container #::Type + result_vargs::Vector{Any} backedges::Vector{Any} src::CodeInfo mod::Module @@ -23,7 +23,7 @@ mutable struct OptimizationState end src = frame.src next_label = max(label_counter(src.code), length(src.code)) + 10 - return new(frame.linfo, frame.vararg_type_container, + return new(frame.linfo, frame.result.vargs, s_edges::Vector{Any}, src, frame.mod, frame.nargs, next_label, frame.min_valid, frame.max_valid, @@ -53,8 +53,8 @@ mutable struct OptimizationState nargs = 0 end next_label = max(label_counter(src.code), length(src.code)) + 10 - vararg_type_container = nothing # if you want something more accurate, set it yourself :P - return new(linfo, vararg_type_container, + result_vargs = Any[] # if you want something more accurate, set it yourself :P + return new(linfo, result_vargs, s_edges::Vector{Any}, src, inmodule, nargs, next_label, @@ -100,11 +100,6 @@ function add_backedge!(li::MethodInstance, caller::OptimizationState) nothing end -function is_specializable_vararg_slot(@nospecialize(arg), sv::OptimizationState) - return (isa(arg, Slot) && slot_id(arg) == sv.nargs && - isa(sv.vararg_type_container, DataType)) -end - ########### # structs # ########### @@ -1333,14 +1328,22 @@ function inlineable(@nospecialize(f), @nospecialize(ft), e::Expr, atypes::Vector if invoke_api(linfo) == 2 # in this case function can be inlined to a constant add_backedge!(linfo, sv) + # XXX: @vtjnash thinks this should be `argexprs0`, but doing so exposes a + # downstream optimizer problem that breaks tests, so we're going to avoid + # changing it for now. ref + # https://github.com/JuliaLang/julia/pull/26826#issuecomment-386381103 return inline_as_constant(linfo.inferred_const, argexprs, sv, invoke_data) end - # see if the method has a InferenceResult in the current cache + # See if the method has a InferenceResult in the current cache # or an existing inferred code info store in `.inferred` + # + # Above, we may have rewritten trailing varargs in `atypes` to a tuple type. 
However, + # inference populates the cache with the pre-rewrite version (`atypes0`), so here, we + # check against that instead. haveconst = false - for i in 1:length(atypes) - a = atypes[i] + for i in 1:length(atypes0) + a = atypes0[i] if isa(a, Const) && !isdefined(typeof(a.val), :instance) && !(isa(a.val, Type) && issingletontype(a.val)) # have new information from argtypes that wasn't available from the signature haveconst = true @@ -1348,7 +1351,7 @@ function inlineable(@nospecialize(f), @nospecialize(ft), e::Expr, atypes::Vector end end if haveconst - inf_result = cache_lookup(linfo, atypes, sv.params.cache) # Union{Nothing, InferenceResult} + inf_result = cache_lookup(linfo, atypes0, sv.params.cache) # Union{Nothing, InferenceResult} else inf_result = nothing end @@ -2003,8 +2006,8 @@ function inline_call(e::Expr, sv::OptimizationState, stmts::Vector{Any}, boundsc tmpv = newvar!(sv, t) push!(newstmts, Expr(:(=), tmpv, aarg)) end - if is_specializable_vararg_slot(aarg, sv) - tp = sv.vararg_type_container.parameters + if is_specializable_vararg_slot(aarg, sv.nargs, sv.result_vargs) + tp = sv.result_vargs else tp = t.parameters end diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index a291efdf61962..e97b2a22baca9 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -223,7 +223,7 @@ add_tfunc(===, 2, 2, end return Bool end, 1) -function isdefined_tfunc(args...) +function isdefined_tfunc(@nospecialize(args...)) arg1 = args[1] if isa(arg1, Const) a1 = typeof(arg1.val) diff --git a/base/compiler/typelimits.jl b/base/compiler/typelimits.jl index 55183addd73f8..6ca2f0dddf9e6 100644 --- a/base/compiler/typelimits.jl +++ b/base/compiler/typelimits.jl @@ -50,15 +50,12 @@ function is_derived_type(@nospecialize(t), @nospecialize(c), mindepth::Int) if t === c return mindepth == 0 end - if isa(c, TypeVar) - # see if it is replacing a TypeVar upper bound with something simpler - return is_derived_type(t, c.ub, mindepth) - elseif isa(c, Union) + if isa(c, Union) # see if it is one of the elements of the union return is_derived_type(t, c.a, mindepth + 1) || is_derived_type(t, c.b, mindepth + 1) elseif isa(c, UnionAll) # see if it is derived from the body - return is_derived_type(t, c.body, mindepth) + return is_derived_type(t, c.var.ub, mindepth) || is_derived_type(t, c.body, mindepth + 1) elseif isa(c, DataType) if isa(t, DataType) # see if it is one of the supertypes of a parameter @@ -96,7 +93,8 @@ function is_derived_type_from_any(@nospecialize(t), sources::SimpleVector, minde return false end -# type vs. comparison or which was derived from source +# The goal of this function is to return a type of greater "size" and less "complexity" than +# both `t` or `c` over the lattice defined by `sources`, `depth`, and `allowed_tuplelen`. 
function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVector, depth::Int, allowed_tuplelen::Int) if t === c return t # quick egal test @@ -140,9 +138,9 @@ function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVec lb = Bottom end v2 = TypeVar(tv.name, lb, ub) - return UnionAll(v2, _limit_type_size(t{v2}, c{v2}, sources, depth + 1, allowed_tuplelen)) + return UnionAll(v2, _limit_type_size(t{v2}, c{v2}, sources, depth, allowed_tuplelen)) end - tbody = _limit_type_size(t.body, c, sources, depth + 1, allowed_tuplelen) + tbody = _limit_type_size(t.body, c, sources, depth, allowed_tuplelen) tbody === t.body && return t return UnionAll(t.var, tbody) elseif isa(c, UnionAll) diff --git a/base/compiler/typeutils.jl b/base/compiler/typeutils.jl index ed74e90f6e26e..f598852b7f2de 100644 --- a/base/compiler/typeutils.jl +++ b/base/compiler/typeutils.jl @@ -99,20 +99,6 @@ function tuple_tail_elem(@nospecialize(init), ct) return Vararg{widenconst(foldl((a, b) -> tmerge(a, tvar_extent(unwrapva(b))), init, ct))} end -# t[n:end] -function tupleparam_tail(t::SimpleVector, n) - lt = length(t) - if n > lt - va = t[lt] - if isvarargtype(va) - # assumes that we should never see Vararg{T, x}, where x is a constant (should be guaranteed by construction) - return Tuple{va} - end - return Tuple{} - end - return Tuple{t[n:lt]...} -end - # take a Tuple where one or more parameters are Unions # and return an array such that those Unions are removed # and `Union{return...} == ty` diff --git a/base/essentials.jl b/base/essentials.jl index 7faea542cf922..4e8a1bd7c40f8 100644 --- a/base/essentials.jl +++ b/base/essentials.jl @@ -214,6 +214,14 @@ function unwrapva(@nospecialize(t)) return isvarargtype(t2) ? rewrap_unionall(t2.parameters[1], t) : t end +function unconstrain_vararg_length(@nospecialize(va)) + # construct a new Vararg type where its length is unconstrained, + # but its element type still captures any dependencies the input + # element type may have had on the input length + T = unwrap_unionall(va).parameters[1] + return rewrap_unionall(Vararg{T}, va) +end + typename(a) = error("typename does not apply to this type") typename(a::DataType) = a.name function typename(a::Union) diff --git a/stdlib/SparseArrays/test/higherorderfns.jl b/stdlib/SparseArrays/test/higherorderfns.jl index 8744f80a39dbe..84f65025e7e42 100644 --- a/stdlib/SparseArrays/test/higherorderfns.jl +++ b/stdlib/SparseArrays/test/higherorderfns.jl @@ -266,16 +266,7 @@ end # --> test broadcast! entry point / not zero-preserving op fQ = broadcast(f, fX, fY, fZ); Q = sparse(fQ) broadcast!(f, Q, X, Y, Z); Q = sparse(fQ) # warmup for @allocated - @test_broken (@allocated broadcast!(f, Q, X, Y, Z)) == 0 - broadcast!(f, Q, X, Y, Z); Q = sparse(fQ) # warmup for @allocated - @test (@allocated broadcast!(f, Q, X, Y, Z)) <= 16 - # the preceding test allocates 16 bytes in the entry point for broadcast!, but - # none of the earlier tests of the same code path allocate. no allocation shows - # up with --track-allocation=user. allocation shows up on the first line of the - # entry point for broadcast! with --track-allocation=all, but that first line - # almost certainly should not allocate. so not certain what's going on. - # additional info: occurs for broadcast!(f, Z, X) for Z and X of different - # shape, but not for Z and X of the same shape. 
+ @test (@allocated broadcast!(f, Q, X, Y, Z)) == 0 @test broadcast!(f, Q, X, Y, Z) == sparse(broadcast!(f, fQ, fX, fY, fZ)) # --> test shape checks for both broadcast and broadcast! entry points # TODO strengthen this test, avoiding dependence on checking whether diff --git a/test/compiler/compiler.jl b/test/compiler/compiler.jl index 4c3a1a26ae974..7e078f2b2452d 100644 --- a/test/compiler/compiler.jl +++ b/test/compiler/compiler.jl @@ -1524,3 +1524,43 @@ f26172(v) = Val{length(Base.tail(ntuple(identity, v)))}() # Val(M-1) g26172(::Val{0}) = () g26172(v) = (nothing, g26172(f26172(v))...) @test @inferred(g26172(Val(10))) === ntuple(_ -> nothing, 10) + +# 26826 constant prop through varargs + +struct Foo26826{A,B} + a::A + b::B +end + +x26826 = rand() + +apply26826(f, args...) = f(args...) + +f26826(x) = apply26826(Base.getproperty, Foo26826(1, x), :b) +# We use getproperty to drive these tests because it requires constant +# propagation in order to lower to a well-inferred getfield call. + +@test @inferred(f26826(x26826)) === x26826 + +getfield26826(x, args...) = Base.getproperty(x, getfield(args, 2)) + +g26826(x) = getfield26826(x, :a, :b) + +@test @inferred(g26826(Foo26826(1, x26826))) === x26826 + +# Somewhere in here should be a single getfield call, and it should be inferred as Float64. +# If this test is broken (especially if inference is getting a correct, but loose result, +# like a Union) then it's potentially an indication that the optimizer isn't hitting the +# InferenceResult cache properly for varargs methods. +typed_code = Core.Compiler.code_typed(f26826, (Float64,))[1].first.code +found_well_typed_getfield_call = false +for stmnt in typed_code + if Meta.isexpr(stmnt, :(=)) && Meta.isexpr(stmnt.args[2], :call) + lhs = stmnt.args[2] + if lhs.args[1] == GlobalRef(Base, :getfield) && lhs.typ === Float64 + global found_well_typed_getfield_call = true + end + end +end + +@test found_well_typed_getfield_call
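As a quick, self-contained illustration of what the new tests above exercise (a sketch, not part of the patch): on a Julia build that includes this change, the constant `:b` passed through the vararg trampoline `apply26826` is memoized as a `Const` in the callee's `InferenceResult.vargs`, so the eventual `getfield` call infers concretely. The definitions below are copied from the `test/compiler/compiler.jl` additions in this diff; the final `code_typed` inspection mirrors the check in the new test. On builds without this change, the `@inferred` line is expected to fail because the call infers as a `Union`.

using Test

struct Foo26826{A,B}
    a::A
    b::B
end

# Vararg "trampoline": the symbol :b must survive as a Const through `args...`
# for getproperty to lower to a well-typed getfield call in the callee.
apply26826(f, args...) = f(args...)
f26826(x) = apply26826(Base.getproperty, Foo26826(1, x), :b)

x26826 = rand()
@test @inferred(f26826(x26826)) === x26826   # concrete Float64, not Union{Int,Float64}

# Inspect the optimized IR, as the new test does; the getfield call's
# inferred type should be Float64 on a build containing this change.
typed_code = Core.Compiler.code_typed(f26826, (Float64,))[1].first.code

The design point the sketch leans on: `InferenceResult` now carries a `vargs` vector that memoizes per-argument `Const` information for the trailing varargs, and `cache_lookup` matches calls against those entries, so the optimizer can locate the constant-specialized inference result during inlining instead of falling back to the widened vararg tuple type.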