From b4b4d212e539883b4ffd98b75521915b03421777 Mon Sep 17 00:00:00 2001 From: Jarrett Revels Date: Mon, 16 Apr 2018 19:07:56 -0400 Subject: [PATCH] enable/improve constant propagation through varargs methods - Store varargs type information in the InferenceResult object, such that the info can be used during inference/optimization - Hack in a more precise return type for getfield of a vararg tuple. Ideally, we would handle this by teaching inference to track the types of the individual fields of a Tuple, which would make this unnecessary, but until then, this hack is helpful. - Spoof parents as well as children during recursion limiting, so that higher degree cycles are appropriately spoofed - A broadcast test marked as broken is now no longer broken, presumably due to the optimizations in this commit - Fix relationship between depth/mindepth in limit_type_size/is_derived_type. The relationship should have been inverse over the domain in which they overlap, but was not maintained consistently. An example of a problematic case was: t = Tuple{X,X} where X<:Tuple{Tuple{Int64,Vararg{Int64,N} where N},Tuple{Int64,Vararg{Int64,N} where N}} c = Tuple{X,X} where X<:Tuple{Int64,Vararg{Int64,N} where N} because is_derived_type was computing the depth of usage rather than the depth of definition. This change thus makes the depth/mindepth calculations more consistent, and causes the limiting heuristic to return strictly wider types than it did before. - Move the optimizer's "varargs types to tuple type" rewrite to after cache lookup. Inference is populating the InferenceResult cache using the varargs form, so the optimizer needs to do the lookup before writing the atypes in order to avoid cache misses. 
Co-authored-by: Jameson Nash Co-authored-by: Keno Fischer --- base/compiler/abstractinterpretation.jl | 43 +++++++++++++++---- base/compiler/inferenceresult.jl | 50 ++++++++++++++++------ base/compiler/inferencestate.jl | 16 ++----- base/compiler/optimize.jl | 33 +++++++------- base/compiler/tfuncs.jl | 2 +- base/compiler/typelimits.jl | 14 +++--- base/compiler/typeutils.jl | 14 ------ base/essentials.jl | 8 ++++ stdlib/SparseArrays/test/higherorderfns.jl | 11 +---- test/compiler/compiler.jl | 40 +++++++++++++++++ 10 files changed, 148 insertions(+), 83 deletions(-) diff --git a/base/compiler/abstractinterpretation.jl b/base/compiler/abstractinterpretation.jl index 68299eb5385cb..e1549b1dd3746 100644 --- a/base/compiler/abstractinterpretation.jl +++ b/base/compiler/abstractinterpretation.jl @@ -116,8 +116,7 @@ function abstract_call_method_with_const_args(@nospecialize(f), argtypes::Vector method.isva && (nargs -= 1) length(argtypes) >= nargs || return Any # probably limit_tuple_type made this non-matching method apparently match haveconst = false - for i in 1:nargs - a = argtypes[i] + for a in argtypes if isa(a, Const) && !isdefined(typeof(a.val), :instance) && !(isa(a.val, Type) && issingletontype(a.val)) # have new information from argtypes that wasn't available from the signature haveconst = true @@ -144,8 +143,7 @@ function abstract_call_method_with_const_args(@nospecialize(f), argtypes::Vector tm = _topmod(sv) if !istopfunction(tm, f, :getproperty) && !istopfunction(tm, f, :setproperty!) 
# in this case, see if all of the arguments are constants - for i in 1:nargs - a = argtypes[i] + for a in argtypes if !isa(a, Const) && !isconstType(a) return Any end @@ -156,6 +154,17 @@ function abstract_call_method_with_const_args(@nospecialize(f), argtypes::Vector if inf_result === nothing inf_result = InferenceResult(code) atypes = get_argtypes(inf_result) + if method.isva + vargs = argtypes[(nargs + 1):end] + for i in 1:length(vargs) + a = vargs[i] + if i > length(inf_result.vargs) + push!(inf_result.vargs, a) + elseif a isa Const + inf_result.vargs[i] = a + end + end + end for i in 1:nargs a = argtypes[i] if a isa Const @@ -187,7 +196,13 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl cyclei = 0 infstate = sv edgecycle = false + # The `method_for_inference_heuristics` will expand the given method's generator if + # necessary in order to retrieve this field from the generated `CodeInfo`, if it exists. + # The other `CodeInfo`s we inspect will already have this field inflated, so we just + # access it directly instead (to avoid regeneration). 
method2 = method_for_inference_heuristics(method, sig, sparams, sv.params.world) # Union{Method, Nothing} + sv_method2 = sv.src.method_for_inference_limit_heuristics # limit only if user token match + sv_method2 isa Method || (sv_method2 = nothing) # Union{Method, Nothing} while !(infstate === nothing) infstate = infstate::InferenceState if method === infstate.linfo.def @@ -208,7 +223,9 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl for parent in infstate.callers_in_cycle # check in the cycle list first # all items in here are mutual parents of all others - if parent.linfo.def === sv.linfo.def + parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match + parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing} + if parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2 topmost = infstate edgecycle = true break @@ -218,7 +235,9 @@ function abstract_call_method(method::Method, @nospecialize(sig), sparams::Simpl # then check the parent link if topmost === nothing && parent !== nothing parent = parent::InferenceState - if parent.cached && parent.linfo.def === sv.linfo.def + parent_method2 = parent.src.method_for_inference_limit_heuristics # limit only if user token match + parent_method2 isa Method || (parent_method2 = nothing) # Union{Method, Nothing} + if parent.cached && parent.linfo.def === sv.linfo.def && sv_method2 === parent_method2 topmost = infstate edgecycle = true end @@ -315,8 +334,8 @@ function precise_container_type(@nospecialize(arg), @nospecialize(typ), vtypes:: end arg = ssa_def_expr(arg, sv) - if is_specializable_vararg_slot(arg, sv) - return Any[rewrap_unionall(p, sv.linfo.specTypes) for p in sv.vararg_type_container.parameters] + if is_specializable_vararg_slot(arg, sv.nargs, sv.result.vargs) + return sv.result.vargs end tti0 = widenconst(typ) @@ -482,7 +501,13 @@ function abstract_call(@nospecialize(f), 
fargs::Union{Tuple{},Vector{Any}}, argt tm = _topmod(sv) if isa(f, Builtin) || isa(f, IntrinsicFunction) rt = builtin_tfunction(f, argtypes[2:end], sv) - if (rt === Bool || (isa(rt, Const) && isa(rt.val, Bool))) && isa(fargs, Vector{Any}) + if f === getfield && isa(fargs, Vector{Any}) && length(argtypes) == 3 && isa(argtypes[3], Const) && isa(argtypes[3].val, Int) && argtypes[2] ⊑ Tuple + cti = precise_container_type(fargs[2], argtypes[2], vtypes, sv) + idx = argtypes[3].val + if 1 <= idx <= length(cti) + rt = unwrapva(cti[idx]) + end + elseif (rt === Bool || (isa(rt, Const) && isa(rt.val, Bool))) && isa(fargs, Vector{Any}) # perform very limited back-propagation of type information for `is` and `isa` if f === isa a = ssa_def_expr(fargs[2], sv) diff --git a/base/compiler/inferenceresult.jl b/base/compiler/inferenceresult.jl index e533f0f13e376..6b9b2a2ba20d8 100644 --- a/base/compiler/inferenceresult.jl +++ b/base/compiler/inferenceresult.jl @@ -5,6 +5,9 @@ const EMPTY_VECTOR = Vector{Any}() mutable struct InferenceResult linfo::MethodInstance args::Vector{Any} + vargs::Vector{Any} # Memoize vararg type info w/Consts here when calling get_argtypes + # on the InferenceResult, so that the optimizer can use this info + # later during inlining. 
result # ::Type, or InferenceState if WIP src::Union{CodeInfo, Nothing} # if inferred copy is available function InferenceResult(linfo::MethodInstance) @@ -13,7 +16,7 @@ mutable struct InferenceResult else result = linfo.rettype end - return new(linfo, EMPTY_VECTOR, result, nothing) + return new(linfo, EMPTY_VECTOR, Any[], result, nothing) end end @@ -31,7 +34,35 @@ function get_argtypes(result::InferenceResult) end vararg_type = Tuple else - vararg_type = rewrap(tupleparam_tail(atypes, nargs), linfo.specTypes) + laty = length(atypes) + if nargs > laty + va = atypes[laty] + if isvarargtype(va) + new_va = rewrap_unionall(unconstrain_vararg_length(va), linfo.specTypes) + vararg_type_vec = Any[new_va] + vararg_type = Tuple{new_va} + else + vararg_type_vec = Any[] + vararg_type = Tuple{} + end + else + vararg_type_vec = Any[] + for p in atypes[nargs:laty] + p = isvarargtype(p) ? unconstrain_vararg_length(p) : p + push!(vararg_type_vec, rewrap_unionall(p, linfo.specTypes)) + end + vararg_type = tuple_tfunc(Tuple{vararg_type_vec...}) + for i in 1:length(vararg_type_vec) + atyp = vararg_type_vec[i] + if isa(atyp, DataType) && isdefined(atyp, :instance) + # replace singleton types with their equivalent Const object + vararg_type_vec[i] = Const(atyp.instance) + elseif isconstType(atyp) + vararg_type_vec[i] = Const(atyp.parameters[1]) + end + end + end + result.vargs = vararg_type_vec end args[nargs] = vararg_type nargs -= 1 @@ -80,19 +111,12 @@ function cache_lookup(code::MethodInstance, argtypes::Vector{Any}, cache::Vector for cache_code in cache # try to search cache first cache_args = cache_code.args - if cache_code.linfo === code && length(cache_args) >= nargs + cache_vargs = cache_code.vargs + if cache_code.linfo === code && length(argtypes) === (length(cache_vargs) + nargs) cache_match = true - # verify that the trailing args (va) aren't Const - for i in (nargs + 1):length(cache_args) - if isa(cache_args[i], Const) - cache_match = false - break - end - end - 
cache_match || continue - for i in 1:nargs + for i in 1:length(argtypes) a = argtypes[i] - ca = cache_args[i] + ca = i <= nargs ? cache_args[i] : cache_vargs[i - nargs] # verify that all Const argument types match between the call and cache if (isa(a, Const) || isa(ca, Const)) && !(a === ca) cache_match = false diff --git a/base/compiler/inferencestate.jl b/base/compiler/inferencestate.jl index ba87c4b36c053..583ca82727b74 100644 --- a/base/compiler/inferencestate.jl +++ b/base/compiler/inferencestate.jl @@ -30,7 +30,6 @@ mutable struct InferenceState # ssavalue sparsity and restart info ssavalue_uses::Vector{BitSet} ssavalue_defs::Vector{LineNum} - vararg_type_container #::Type backedges::Vector{Tuple{InferenceState, LineNum}} # call-graph backedges connecting from callee to caller callers_in_cycle::Vector{InferenceState} @@ -102,19 +101,11 @@ mutable struct InferenceState # initial types nslots = length(src.slotnames) argtypes = get_argtypes(result) - vararg_type_container = nothing nargs = length(argtypes) s_argtypes = VarTable(undef, nslots) src.slottypes = Vector{Any}(undef, nslots) for i in 1:nslots at = (i > nargs) ? 
Bottom : argtypes[i] - if !toplevel && linfo.def.isva && i == nargs - if !(at == Tuple) # would just be a no-op - vararg_type_container = unwrap_unionall(at) - vararg_type = tuple_tfunc(vararg_type_container) # returns a Const object, if applicable - at = rewrap(vararg_type, linfo.specTypes) - end - end s_argtypes[i] = VarState(at, i > nargs) src.slottypes[i] = at end @@ -152,7 +143,7 @@ mutable struct InferenceState nargs, s_types, s_edges, Union{}, W, 1, n, cur_hand, handler_at, n_handlers, - ssavalue_uses, ssavalue_defs, vararg_type_container, + ssavalue_uses, ssavalue_defs, Vector{Tuple{InferenceState,LineNum}}(), # backedges Vector{InferenceState}(), # callers_in_cycle #=parent=#nothing, @@ -238,9 +229,8 @@ function add_mt_backedge!(mt::Core.MethodTable, @nospecialize(typ), caller::Infe nothing end -function is_specializable_vararg_slot(@nospecialize(arg), sv::InferenceState) - return (isa(arg, Slot) && slot_id(arg) == sv.nargs && - isa(sv.vararg_type_container, DataType)) +function is_specializable_vararg_slot(@nospecialize(arg), nargs, vargs) + return (isa(arg, Slot) && slot_id(arg) == nargs && !isempty(vargs)) end function print_callstack(sv::InferenceState) diff --git a/base/compiler/optimize.jl b/base/compiler/optimize.jl index eb213270fe0ad..5943037cf33c5 100644 --- a/base/compiler/optimize.jl +++ b/base/compiler/optimize.jl @@ -6,7 +6,7 @@ mutable struct OptimizationState linfo::MethodInstance - vararg_type_container #::Type + result_vargs::Vector{Any} backedges::Vector{Any} src::CodeInfo mod::Module @@ -23,7 +23,7 @@ mutable struct OptimizationState end src = frame.src next_label = max(label_counter(src.code), length(src.code)) + 10 - return new(frame.linfo, frame.vararg_type_container, + return new(frame.linfo, frame.result.vargs, s_edges::Vector{Any}, src, frame.mod, frame.nargs, next_label, frame.min_valid, frame.max_valid, @@ -53,8 +53,8 @@ mutable struct OptimizationState nargs = 0 end next_label = max(label_counter(src.code), length(src.code)) + 
10 - vararg_type_container = nothing # if you want something more accurate, set it yourself :P - return new(linfo, vararg_type_container, + result_vargs = Any[] # if you want something more accurate, set it yourself :P + return new(linfo, result_vargs, s_edges::Vector{Any}, src, inmodule, nargs, next_label, @@ -100,11 +100,6 @@ function add_backedge!(li::MethodInstance, caller::OptimizationState) nothing end -function is_specializable_vararg_slot(@nospecialize(arg), sv::OptimizationState) - return (isa(arg, Slot) && slot_id(arg) == sv.nargs && - isa(sv.vararg_type_container, DataType)) -end - ########### # structs # ########### @@ -1333,14 +1328,22 @@ function inlineable(@nospecialize(f), @nospecialize(ft), e::Expr, atypes::Vector if invoke_api(linfo) == 2 # in this case function can be inlined to a constant add_backedge!(linfo, sv) + # XXX: @vtjnash thinks this should be `argexprs0`, but doing so exposes a + # downstream optimizer problem that breaks tests, so we're going to avoid + # changing it for now. ref + # https://github.com/JuliaLang/julia/pull/26826#issuecomment-386381103 return inline_as_constant(linfo.inferred_const, argexprs, sv, invoke_data) end - # see if the method has a InferenceResult in the current cache + # See if the method has a InferenceResult in the current cache # or an existing inferred code info store in `.inferred` + # + # Above, we may have rewritten trailing varargs in `atypes` to a tuple type. However, + # inference populates the cache with the pre-rewrite version (`atypes0`), so here, we + # check against that instead. 
haveconst = false - for i in 1:length(atypes) - a = atypes[i] + for i in 1:length(atypes0) + a = atypes0[i] if isa(a, Const) && !isdefined(typeof(a.val), :instance) && !(isa(a.val, Type) && issingletontype(a.val)) # have new information from argtypes that wasn't available from the signature haveconst = true @@ -1348,7 +1351,7 @@ function inlineable(@nospecialize(f), @nospecialize(ft), e::Expr, atypes::Vector end end if haveconst - inf_result = cache_lookup(linfo, atypes, sv.params.cache) # Union{Nothing, InferenceResult} + inf_result = cache_lookup(linfo, atypes0, sv.params.cache) # Union{Nothing, InferenceResult} else inf_result = nothing end @@ -2003,8 +2006,8 @@ function inline_call(e::Expr, sv::OptimizationState, stmts::Vector{Any}, boundsc tmpv = newvar!(sv, t) push!(newstmts, Expr(:(=), tmpv, aarg)) end - if is_specializable_vararg_slot(aarg, sv) - tp = sv.vararg_type_container.parameters + if is_specializable_vararg_slot(aarg, sv.nargs, sv.result_vargs) + tp = sv.result_vargs else tp = t.parameters end diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index a291efdf61962..e97b2a22baca9 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -223,7 +223,7 @@ add_tfunc(===, 2, 2, end return Bool end, 1) -function isdefined_tfunc(args...) 
+function isdefined_tfunc(@nospecialize(args...)) arg1 = args[1] if isa(arg1, Const) a1 = typeof(arg1.val) diff --git a/base/compiler/typelimits.jl b/base/compiler/typelimits.jl index 55183addd73f8..6ca2f0dddf9e6 100644 --- a/base/compiler/typelimits.jl +++ b/base/compiler/typelimits.jl @@ -50,15 +50,12 @@ function is_derived_type(@nospecialize(t), @nospecialize(c), mindepth::Int) if t === c return mindepth == 0 end - if isa(c, TypeVar) - # see if it is replacing a TypeVar upper bound with something simpler - return is_derived_type(t, c.ub, mindepth) - elseif isa(c, Union) + if isa(c, Union) # see if it is one of the elements of the union return is_derived_type(t, c.a, mindepth + 1) || is_derived_type(t, c.b, mindepth + 1) elseif isa(c, UnionAll) # see if it is derived from the body - return is_derived_type(t, c.body, mindepth) + return is_derived_type(t, c.var.ub, mindepth) || is_derived_type(t, c.body, mindepth + 1) elseif isa(c, DataType) if isa(t, DataType) # see if it is one of the supertypes of a parameter @@ -96,7 +93,8 @@ function is_derived_type_from_any(@nospecialize(t), sources::SimpleVector, minde return false end -# type vs. comparison or which was derived from source +# The goal of this function is to return a type of greater "size" and less "complexity" than +# both `t` or `c` over the lattice defined by `sources`, `depth`, and `allowed_tuplelen`. 
function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVector, depth::Int, allowed_tuplelen::Int) if t === c return t # quick egal test @@ -140,9 +138,9 @@ function _limit_type_size(@nospecialize(t), @nospecialize(c), sources::SimpleVec lb = Bottom end v2 = TypeVar(tv.name, lb, ub) - return UnionAll(v2, _limit_type_size(t{v2}, c{v2}, sources, depth + 1, allowed_tuplelen)) + return UnionAll(v2, _limit_type_size(t{v2}, c{v2}, sources, depth, allowed_tuplelen)) end - tbody = _limit_type_size(t.body, c, sources, depth + 1, allowed_tuplelen) + tbody = _limit_type_size(t.body, c, sources, depth, allowed_tuplelen) tbody === t.body && return t return UnionAll(t.var, tbody) elseif isa(c, UnionAll) diff --git a/base/compiler/typeutils.jl b/base/compiler/typeutils.jl index ed74e90f6e26e..f598852b7f2de 100644 --- a/base/compiler/typeutils.jl +++ b/base/compiler/typeutils.jl @@ -99,20 +99,6 @@ function tuple_tail_elem(@nospecialize(init), ct) return Vararg{widenconst(foldl((a, b) -> tmerge(a, tvar_extent(unwrapva(b))), init, ct))} end -# t[n:end] -function tupleparam_tail(t::SimpleVector, n) - lt = length(t) - if n > lt - va = t[lt] - if isvarargtype(va) - # assumes that we should never see Vararg{T, x}, where x is a constant (should be guaranteed by construction) - return Tuple{va} - end - return Tuple{} - end - return Tuple{t[n:lt]...} -end - # take a Tuple where one or more parameters are Unions # and return an array such that those Unions are removed # and `Union{return...} == ty` diff --git a/base/essentials.jl b/base/essentials.jl index 7faea542cf922..4e8a1bd7c40f8 100644 --- a/base/essentials.jl +++ b/base/essentials.jl @@ -214,6 +214,14 @@ function unwrapva(@nospecialize(t)) return isvarargtype(t2) ? 
rewrap_unionall(t2.parameters[1], t) : t end +function unconstrain_vararg_length(@nospecialize(va)) + # construct a new Vararg type where its length is unconstrained, + # but its element type still captures any dependencies the input + # element type may have had on the input length + T = unwrap_unionall(va).parameters[1] + return rewrap_unionall(Vararg{T}, va) +end + typename(a) = error("typename does not apply to this type") typename(a::DataType) = a.name function typename(a::Union) diff --git a/stdlib/SparseArrays/test/higherorderfns.jl b/stdlib/SparseArrays/test/higherorderfns.jl index 8744f80a39dbe..84f65025e7e42 100644 --- a/stdlib/SparseArrays/test/higherorderfns.jl +++ b/stdlib/SparseArrays/test/higherorderfns.jl @@ -266,16 +266,7 @@ end # --> test broadcast! entry point / not zero-preserving op fQ = broadcast(f, fX, fY, fZ); Q = sparse(fQ) broadcast!(f, Q, X, Y, Z); Q = sparse(fQ) # warmup for @allocated - @test_broken (@allocated broadcast!(f, Q, X, Y, Z)) == 0 - broadcast!(f, Q, X, Y, Z); Q = sparse(fQ) # warmup for @allocated - @test (@allocated broadcast!(f, Q, X, Y, Z)) <= 16 - # the preceding test allocates 16 bytes in the entry point for broadcast!, but - # none of the earlier tests of the same code path allocate. no allocation shows - # up with --track-allocation=user. allocation shows up on the first line of the - # entry point for broadcast! with --track-allocation=all, but that first line - # almost certainly should not allocate. so not certain what's going on. - # additional info: occurs for broadcast!(f, Z, X) for Z and X of different - # shape, but not for Z and X of the same shape. + @test (@allocated broadcast!(f, Q, X, Y, Z)) == 0 @test broadcast!(f, Q, X, Y, Z) == sparse(broadcast!(f, fQ, fX, fY, fZ)) # --> test shape checks for both broadcast and broadcast! 
entry points # TODO strengthen this test, avoiding dependence on checking whether diff --git a/test/compiler/compiler.jl b/test/compiler/compiler.jl index 4c3a1a26ae974..7e078f2b2452d 100644 --- a/test/compiler/compiler.jl +++ b/test/compiler/compiler.jl @@ -1524,3 +1524,43 @@ f26172(v) = Val{length(Base.tail(ntuple(identity, v)))}() # Val(M-1) g26172(::Val{0}) = () g26172(v) = (nothing, g26172(f26172(v))...) @test @inferred(g26172(Val(10))) === ntuple(_ -> nothing, 10) + +# 26826 constant prop through varargs + +struct Foo26826{A,B} + a::A + b::B +end + +x26826 = rand() + +apply26826(f, args...) = f(args...) + +f26826(x) = apply26826(Base.getproperty, Foo26826(1, x), :b) +# We use getproperty to drive these tests because it requires constant +# propagation in order to lower to a well-inferred getfield call. + +@test @inferred(f26826(x26826)) === x26826 + +getfield26826(x, args...) = Base.getproperty(x, getfield(args, 2)) + +g26826(x) = getfield26826(x, :a, :b) + +@test @inferred(g26826(Foo26826(1, x26826))) === x26826 + +# Somewhere in here should be a single getfield call, and it should be inferred as Float64. +# If this test is broken (especially if inference is getting a correct, but loose result, +# like a Union) then it's potentially an indication that the optimizer isn't hitting the +# InferenceResult cache properly for varargs methods. +typed_code = Core.Compiler.code_typed(f26826, (Float64,))[1].first.code +found_well_typed_getfield_call = false +for stmnt in typed_code + if Meta.isexpr(stmnt, :(=)) && Meta.isexpr(stmnt.args[2], :call) + lhs = stmnt.args[2] + if lhs.args[1] == GlobalRef(Base, :getfield) && lhs.typ === Float64 + global found_well_typed_getfield_call = true + end + end +end + +@test found_well_typed_getfield_call