From ff88fa446f44b8bcde15cea8c29549a6fff65375 Mon Sep 17 00:00:00 2001 From: Jameson Nash Date: Sat, 12 Mar 2022 19:09:16 -0500 Subject: [PATCH] inference: refine PartialStruct lattice tmerge (#44404) * inference: fix tmerge lattice over issimpleenoughtype Previously we assumed only union type could have complexity that violated the tmerge lattice requirements, but other types can have that too. This lets us fix an issue with the PartialStruct comparison failing for undefined fields, mentioned in #43784. * inference: refine PartialStruct lattice tmerge Be more aggressive about merging fields to greatly accelerate convergence, but also compute anyrefine more correctly as we do now elsewhere (since #42831, a121721f975fc4105ed24ebd0ad1020d08d07a38) Move the tmeet algorithm, without changes, since it is a precise lattice operation, not a heuristic limit like tmerge. Close #43784 --- base/compiler/typelattice.jl | 86 ++++++++++++++++++++++- base/compiler/typelimits.jl | 132 +++++++++++++++++++++-------------- test/compiler/inference.jl | 28 ++++++-- 3 files changed, 189 insertions(+), 57 deletions(-) diff --git a/base/compiler/typelattice.jl b/base/compiler/typelattice.jl index bba9f41bf64d3..f6eb92a040f2a 100644 --- a/base/compiler/typelattice.jl +++ b/base/compiler/typelattice.jl @@ -200,7 +200,7 @@ The non-strict partial order over the type inference lattice. end for i in 1:nfields(a.val) # XXX: let's handle varargs later - isdefined(a.val, i) || return false + isdefined(a.val, i) || continue # since ∀ T Union{} ⊑ T ⊑(Const(getfield(a.val, i)), b.fields[i]) || return false end return true @@ -289,6 +289,48 @@ function is_lattice_equal(@nospecialize(a), @nospecialize(b)) return a ⊑ b && b ⊑ a end +# compute typeintersect over the extended inference lattice, +# as precisely as we can, +# where v is in the extended lattice, and t is a Type. +function tmeet(@nospecialize(v), @nospecialize(t)) + if isa(v, Const) + if !has_free_typevars(t) && !isa(v.val, t) + return Bottom + end + return v + elseif isa(v, PartialStruct) + has_free_typevars(t) && return v + widev = widenconst(v) + if widev <: t + return v + end + ti = typeintersect(widev, t) + valid_as_lattice(ti) || return Bottom + @assert widev <: Tuple + new_fields = Vector{Any}(undef, length(v.fields)) + for i = 1:length(new_fields) + vfi = v.fields[i] + if isvarargtype(vfi) + new_fields[i] = vfi + else + new_fields[i] = tmeet(vfi, widenconst(getfield_tfunc(t, Const(i)))) + if new_fields[i] === Bottom + return Bottom + end + end + end + return tuple_tfunc(new_fields) + elseif isa(v, Conditional) + if !(Bool <: t) + return Bottom + end + return v + end + ti = typeintersect(widenconst(v), t) + valid_as_lattice(ti) || return Bottom + return ti +end + widenconst(c::AnyConditional) = Bool widenconst((; val)::Const) = isa(val, Type) ? Type{val} : typeof(val) widenconst(m::MaybeUndef) = widenconst(m.typ) @@ -427,3 +469,45 @@ function stupdate1!(state::VarTable, change::StateUpdate) end return false end + +# compute typeintersect over the extended inference lattice, +# as precisely as we can, +# where v is in the extended lattice, and t is a Type. +function tmeet(@nospecialize(v), @nospecialize(t)) + if isa(v, Const) + if !has_free_typevars(t) && !isa(v.val, t) + return Bottom + end + return v + elseif isa(v, PartialStruct) + has_free_typevars(t) && return v + widev = widenconst(v) + if widev <: t + return v + end + ti = typeintersect(widev, t) + valid_as_lattice(ti) || return Bottom + @assert widev <: Tuple + new_fields = Vector{Any}(undef, length(v.fields)) + for i = 1:length(new_fields) + vfi = v.fields[i] + if isvarargtype(vfi) + new_fields[i] = vfi + else + new_fields[i] = tmeet(vfi, widenconst(getfield_tfunc(t, Const(i)))) + if new_fields[i] === Bottom + return Bottom + end + end + end + return tuple_tfunc(new_fields) + elseif isa(v, Conditional) + if !(Bool <: t) + return Bottom + end + return v + end + ti = typeintersect(widenconst(v), t) + valid_as_lattice(ti) || return Bottom + return ti +end diff --git a/base/compiler/typelimits.jl b/base/compiler/typelimits.jl index a7989777317c3..d25c77deb6d2e 100644 --- a/base/compiler/typelimits.jl +++ b/base/compiler/typelimits.jl @@ -298,11 +298,57 @@ union_count_abstract(x::Union) = union_count_abstract(x.a) + union_count_abstrac union_count_abstract(@nospecialize(x)) = !isdispatchelem(x) function issimpleenoughtype(@nospecialize t) - t = ignorelimited(t) return unionlen(t) + union_count_abstract(t) <= MAX_TYPEUNION_LENGTH && unioncomplexity(t) <= MAX_TYPEUNION_COMPLEXITY end +# A simplified type_more_complex query over the extended lattice +# (assumes typeb ⊑ typea) +function issimplertype(@nospecialize(typea), @nospecialize(typeb)) + typea = ignorelimited(typea) + typeb = ignorelimited(typeb) + typea isa MaybeUndef && (typea = typea.typ) # n.b. does not appear in inference + typeb isa MaybeUndef && (typeb = typeb.typ) # n.b. does not appear in inference + typea === typeb && return true + if typea isa PartialStruct + aty = widenconst(typea) + for i = 1:length(typea.fields) + ai = typea.fields[i] + bi = fieldtype(aty, i) + is_lattice_equal(ai, bi) && continue + tni = _typename(widenconst(ai)) + if tni isa Const + bi = (tni.val::Core.TypeName).wrapper + is_lattice_equal(ai, bi) && continue + end + bi = getfield_tfunc(typeb, Const(i)) + is_lattice_equal(ai, bi) && continue + # It is not enough for ai to be simpler than bi: it must exactly equal + # (for this, an invariant struct field, by contrast to + # type_more_complex above which handles covariant tuples). + return false + end + elseif typea isa Type + return issimpleenoughtype(typea) + # elseif typea isa Const # fall-through good + elseif typea isa Conditional # follow issubconditional query + typeb isa Const && return true + typeb isa Conditional || return false + is_same_conditionals(typea, typeb) || return false + issimplertype(typea.vtype, typeb.vtype) || return false + issimplertype(typea.elsetype, typeb.elsetype) || return false + elseif typea isa InterConditional # ibid + typeb isa Const && return true + typeb isa InterConditional || return false + is_same_conditionals(typea, typeb) || return false + issimplertype(typea.vtype, typeb.vtype) || return false + issimplertype(typea.elsetype, typeb.elsetype) || return false + elseif typea isa PartialOpaque + # TODO + end + return true +end + # pick a wider type that contains both typea and typeb, # with some limits on how "large" it can get, # but without losing too much precision in common cases @@ -310,11 +356,13 @@ end function tmerge(@nospecialize(typea), @nospecialize(typeb)) typea === Union{} && return typeb typeb === Union{} && return typea + typea === typeb && return typea + suba = typea ⊑ typeb - suba && issimpleenoughtype(typeb) && return typeb + suba && issimplertype(typeb, typea) && return typeb subb = typeb ⊑ typea suba && subb && return typea - subb && issimpleenoughtype(typea) && return typea + subb && issimplertype(typea, typeb) && return typea # type-lattice for LimitedAccuracy wrapper # the merge create a slightly narrower type than needed, but we can't @@ -404,6 +452,7 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb)) aty = widenconst(typea) bty = widenconst(typeb) if aty === bty + # must have egal here, since we do not create PartialStruct for non-concrete types typea_nfields = nfields_tfunc(typea) typeb_nfields = nfields_tfunc(typeb) isa(typea_nfields, Const) || return aty @@ -412,18 +461,40 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb)) type_nfields === typeb_nfields.val::Int || return aty type_nfields == 0 && return aty fields = Vector{Any}(undef, type_nfields) - anyconst = false + anyrefine = false for i = 1:type_nfields ai = getfield_tfunc(typea, Const(i)) bi = getfield_tfunc(typeb, Const(i)) - ity = tmerge(ai, bi) - if ai === Union{} || bi === Union{} - ity = widenconst(ity) + ft = fieldtype(aty, i) + if is_lattice_equal(ai, bi) || is_lattice_equal(ai, ft) + # Since ai===bi, the given type has no restrictions on complexity. + # and can be used to refine ft + tyi = ai + elseif is_lattice_equal(bi, ft) + tyi = bi + else + # Otherwise choose between using the fieldtype or some other simple merged type. + # The wrapper type never has restrictions on complexity, + # so try to use that to refine the estimated type too. + tni = _typename(widenconst(ai)) + if tni isa Const && tni === _typename(widenconst(bi)) + # A tmeet call may cause tyi to become complex, but since the inputs were + # strictly limited to being egal, this has no restrictions on complexity. + # (Otherwise, we would need to use <: and take the narrower one without + # intersection. See the similar comment in abstract_call_method.) + tyi = typeintersect(ft, (tni.val::Core.TypeName).wrapper) + else + # Since aty===bty, the fieldtype has no restrictions on complexity. + tyi = ft + end + end + fields[i] = tyi + if !anyrefine + anyrefine = has_nontrivial_const_info(tyi) || # constant information + tyi ⋤ ft # just a type-level information, but more precise than the declared type end - fields[i] = ity - anyconst |= has_nontrivial_const_info(ity) end - return anyconst ? PartialStruct(aty, fields) : aty + return anyrefine ? PartialStruct(aty, fields) : aty end end if isa(typea, PartialOpaque) && isa(typeb, PartialOpaque) && widenconst(typea) == widenconst(typeb) @@ -610,44 +681,3 @@ function tuplemerge(a::DataType, b::DataType) end return Tuple{p...} end - -# compute typeintersect over the extended inference lattice -# where v is in the extended lattice, and t is a Type -function tmeet(@nospecialize(v), @nospecialize(t)) - if isa(v, Const) - if !has_free_typevars(t) && !isa(v.val, t) - return Bottom - end - return v - elseif isa(v, PartialStruct) - has_free_typevars(t) && return v - widev = widenconst(v) - if widev <: t - return v - end - ti = typeintersect(widev, t) - valid_as_lattice(ti) || return Bottom - @assert widev <: Tuple - new_fields = Vector{Any}(undef, length(v.fields)) - for i = 1:length(new_fields) - vfi = v.fields[i] - if isvarargtype(vfi) - new_fields[i] = vfi - else - new_fields[i] = tmeet(vfi, widenconst(getfield_tfunc(t, Const(i)))) - if new_fields[i] === Bottom - return Bottom - end - end - end - return tuple_tfunc(new_fields) - elseif isa(v, Conditional) - if !(Bool <: t) - return Bottom - end - return v - end - ti = typeintersect(widenconst(v), t) - valid_as_lattice(ti) || return Bottom - return ti -end diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 1349e7da398fb..5d160143483af 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -3999,15 +3999,33 @@ end @test ⊑(a, c) @test ⊑(b, c) - @test @eval Module() begin - const ginit = Base.ImmutableDict{Any,Any}() - Base.return_types() do - g = ginit + init = Base.ImmutableDict{Number,Number}() + a = Const(init) + b = Core.PartialStruct(typeof(init), Any[Const(init), Any, ComplexF64]) + c = Core.Compiler.tmerge(a, b) + @test ⊑(a, c) && ⊑(b, c) + @test c === typeof(init) + + a = Core.PartialStruct(typeof(init), Any[Const(init), ComplexF64, ComplexF64]) + c = Core.Compiler.tmerge(a, b) + @test ⊑(a, c) && ⊑(b, c) + @test c.fields[2] === Any # or Number + @test c.fields[3] === ComplexF64 + + b = Core.PartialStruct(typeof(init), Any[Const(init), ComplexF32, Union{ComplexF32,ComplexF64}]) + c = Core.Compiler.tmerge(a, b) + @test ⊑(a, c) + @test ⊑(b, c) + @test c.fields[2] === Complex + @test c.fields[3] === Complex + + global const ginit43784 = Base.ImmutableDict{Any,Any}() + @test Base.return_types() do + g = ginit43784 while true g = Base.ImmutableDict(g, 1=>2) end end |> only === Union{} - end end # Test that purity modeling doesn't accidentally introduce new world age issues