Skip to content

Commit

Permalink
inference: refine PartialStruct lattice tmerge (#44404)
Browse files Browse the repository at this point in the history
* inference: fix tmerge lattice over issimpleenoughtype

Previously we assumed only union type could have complexity that
violated the tmerge lattice requirements, but other types can have that
too. This lets us fix an issue with the PartialStruct comparison failing
for undefined fields, mentioned in #43784.

* inference: refine PartialStruct lattice tmerge

Be more aggressive about merging fields to greatly accelerate
convergence, but also compute anyrefine more correctly as we do now
elsewhere (since #42831, a121721)

Move the tmeet algorithm, without changes, since it is a precise lattice
operation, not a heuristic limit like tmerge.

Close #43784
  • Loading branch information
vtjnash authored Mar 13, 2022
1 parent ceec252 commit ff88fa4
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 57 deletions.
86 changes: 85 additions & 1 deletion base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ The non-strict partial order over the type inference lattice.
end
for i in 1:nfields(a.val)
# XXX: let's handle varargs later
isdefined(a.val, i) || return false
isdefined(a.val, i) || continue # since ∀ T Union{} ⊑ T
(Const(getfield(a.val, i)), b.fields[i]) || return false
end
return true
Expand Down Expand Up @@ -289,6 +289,48 @@ function is_lattice_equal(@nospecialize(a), @nospecialize(b))
return a b && b a
end

# compute typeintersect over the extended inference lattice,
# as precisely as we can,
# where v is in the extended lattice, and t is a Type.
function tmeet(@nospecialize(v), @nospecialize(t))
if isa(v, Const)
if !has_free_typevars(t) && !isa(v.val, t)
return Bottom
end
return v
elseif isa(v, PartialStruct)
has_free_typevars(t) && return v
widev = widenconst(v)
if widev <: t
return v
end
ti = typeintersect(widev, t)
valid_as_lattice(ti) || return Bottom
@assert widev <: Tuple
new_fields = Vector{Any}(undef, length(v.fields))
for i = 1:length(new_fields)
vfi = v.fields[i]
if isvarargtype(vfi)
new_fields[i] = vfi
else
new_fields[i] = tmeet(vfi, widenconst(getfield_tfunc(t, Const(i))))
if new_fields[i] === Bottom
return Bottom
end
end
end
return tuple_tfunc(new_fields)
elseif isa(v, Conditional)
if !(Bool <: t)
return Bottom
end
return v
end
ti = typeintersect(widenconst(v), t)
valid_as_lattice(ti) || return Bottom
return ti
end

widenconst(c::AnyConditional) = Bool
widenconst((; val)::Const) = isa(val, Type) ? Type{val} : typeof(val)
widenconst(m::MaybeUndef) = widenconst(m.typ)
Expand Down Expand Up @@ -427,3 +469,45 @@ function stupdate1!(state::VarTable, change::StateUpdate)
end
return false
end

# compute typeintersect over the extended inference lattice,
# as precisely as we can,
# where v is in the extended lattice, and t is a Type.
function tmeet(@nospecialize(v), @nospecialize(t))
if isa(v, Const)
if !has_free_typevars(t) && !isa(v.val, t)
return Bottom
end
return v
elseif isa(v, PartialStruct)
has_free_typevars(t) && return v
widev = widenconst(v)
if widev <: t
return v
end
ti = typeintersect(widev, t)
valid_as_lattice(ti) || return Bottom
@assert widev <: Tuple
new_fields = Vector{Any}(undef, length(v.fields))
for i = 1:length(new_fields)
vfi = v.fields[i]
if isvarargtype(vfi)
new_fields[i] = vfi
else
new_fields[i] = tmeet(vfi, widenconst(getfield_tfunc(t, Const(i))))
if new_fields[i] === Bottom
return Bottom
end
end
end
return tuple_tfunc(new_fields)
elseif isa(v, Conditional)
if !(Bool <: t)
return Bottom
end
return v
end
ti = typeintersect(widenconst(v), t)
valid_as_lattice(ti) || return Bottom
return ti
end
132 changes: 81 additions & 51 deletions base/compiler/typelimits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,23 +298,71 @@ union_count_abstract(x::Union) = union_count_abstract(x.a) + union_count_abstrac
union_count_abstract(@nospecialize(x)) = !isdispatchelem(x)

function issimpleenoughtype(@nospecialize t)
t = ignorelimited(t)
return unionlen(t) + union_count_abstract(t) <= MAX_TYPEUNION_LENGTH &&
unioncomplexity(t) <= MAX_TYPEUNION_COMPLEXITY
end

# A simplified type_more_complex query over the extended lattice
# (assumes typeb ⊑ typea)
function issimplertype(@nospecialize(typea), @nospecialize(typeb))
typea = ignorelimited(typea)
typeb = ignorelimited(typeb)
typea isa MaybeUndef && (typea = typea.typ) # n.b. does not appear in inference
typeb isa MaybeUndef && (typeb = typeb.typ) # n.b. does not appear in inference
typea === typeb && return true
if typea isa PartialStruct
aty = widenconst(typea)
for i = 1:length(typea.fields)
ai = typea.fields[i]
bi = fieldtype(aty, i)
is_lattice_equal(ai, bi) && continue
tni = _typename(widenconst(ai))
if tni isa Const
bi = (tni.val::Core.TypeName).wrapper
is_lattice_equal(ai, bi) && continue
end
bi = getfield_tfunc(typeb, Const(i))
is_lattice_equal(ai, bi) && continue
# It is not enough for ai to be simpler than bi: it must exactly equal
# (for this, an invariant struct field, by contrast to
# type_more_complex above which handles covariant tuples).
return false
end
elseif typea isa Type
return issimpleenoughtype(typea)
# elseif typea isa Const # fall-through good
elseif typea isa Conditional # follow issubconditional query
typeb isa Const && return true
typeb isa Conditional || return false
is_same_conditionals(typea, typeb) || return false
issimplertype(typea.vtype, typeb.vtype) || return false
issimplertype(typea.elsetype, typeb.elsetype) || return false
elseif typea isa InterConditional # ibid
typeb isa Const && return true
typeb isa InterConditional || return false
is_same_conditionals(typea, typeb) || return false
issimplertype(typea.vtype, typeb.vtype) || return false
issimplertype(typea.elsetype, typeb.elsetype) || return false
elseif typea isa PartialOpaque
# TODO
end
return true
end

# pick a wider type that contains both typea and typeb,
# with some limits on how "large" it can get,
# but without losing too much precision in common cases
# and also trying to be mostly associative and commutative
function tmerge(@nospecialize(typea), @nospecialize(typeb))
typea === Union{} && return typeb
typeb === Union{} && return typea
typea === typeb && return typea

suba = typea typeb
suba && issimpleenoughtype(typeb) && return typeb
suba && issimplertype(typeb, typea) && return typeb
subb = typeb typea
suba && subb && return typea
subb && issimpleenoughtype(typea) && return typea
subb && issimplertype(typea, typeb) && return typea

# type-lattice for LimitedAccuracy wrapper
# the merge create a slightly narrower type than needed, but we can't
Expand Down Expand Up @@ -404,6 +452,7 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb))
aty = widenconst(typea)
bty = widenconst(typeb)
if aty === bty
# must have egal here, since we do not create PartialStruct for non-concrete types
typea_nfields = nfields_tfunc(typea)
typeb_nfields = nfields_tfunc(typeb)
isa(typea_nfields, Const) || return aty
Expand All @@ -412,18 +461,40 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb))
type_nfields === typeb_nfields.val::Int || return aty
type_nfields == 0 && return aty
fields = Vector{Any}(undef, type_nfields)
anyconst = false
anyrefine = false
for i = 1:type_nfields
ai = getfield_tfunc(typea, Const(i))
bi = getfield_tfunc(typeb, Const(i))
ity = tmerge(ai, bi)
if ai === Union{} || bi === Union{}
ity = widenconst(ity)
ft = fieldtype(aty, i)
if is_lattice_equal(ai, bi) || is_lattice_equal(ai, ft)
# Since ai===bi, the given type has no restrictions on complexity.
# and can be used to refine ft
tyi = ai
elseif is_lattice_equal(bi, ft)
tyi = bi
else
# Otherwise choose between using the fieldtype or some other simple merged type.
# The wrapper type never has restrictions on complexity,
# so try to use that to refine the estimated type too.
tni = _typename(widenconst(ai))
if tni isa Const && tni === _typename(widenconst(bi))
# A tmeet call may cause tyi to become complex, but since the inputs were
# strictly limited to being egal, this has no restrictions on complexity.
# (Otherwise, we would need to use <: and take the narrower one without
# intersection. See the similar comment in abstract_call_method.)
tyi = typeintersect(ft, (tni.val::Core.TypeName).wrapper)
else
# Since aty===bty, the fieldtype has no restrictions on complexity.
tyi = ft
end
end
fields[i] = tyi
if !anyrefine
anyrefine = has_nontrivial_const_info(tyi) || # constant information
tyi ft # just a type-level information, but more precise than the declared type
end
fields[i] = ity
anyconst |= has_nontrivial_const_info(ity)
end
return anyconst ? PartialStruct(aty, fields) : aty
return anyrefine ? PartialStruct(aty, fields) : aty
end
end
if isa(typea, PartialOpaque) && isa(typeb, PartialOpaque) && widenconst(typea) == widenconst(typeb)
Expand Down Expand Up @@ -610,44 +681,3 @@ function tuplemerge(a::DataType, b::DataType)
end
return Tuple{p...}
end

# compute typeintersect over the extended inference lattice
# where v is in the extended lattice, and t is a Type
function tmeet(@nospecialize(v), @nospecialize(t))
if isa(v, Const)
if !has_free_typevars(t) && !isa(v.val, t)
return Bottom
end
return v
elseif isa(v, PartialStruct)
has_free_typevars(t) && return v
widev = widenconst(v)
if widev <: t
return v
end
ti = typeintersect(widev, t)
valid_as_lattice(ti) || return Bottom
@assert widev <: Tuple
new_fields = Vector{Any}(undef, length(v.fields))
for i = 1:length(new_fields)
vfi = v.fields[i]
if isvarargtype(vfi)
new_fields[i] = vfi
else
new_fields[i] = tmeet(vfi, widenconst(getfield_tfunc(t, Const(i))))
if new_fields[i] === Bottom
return Bottom
end
end
end
return tuple_tfunc(new_fields)
elseif isa(v, Conditional)
if !(Bool <: t)
return Bottom
end
return v
end
ti = typeintersect(widenconst(v), t)
valid_as_lattice(ti) || return Bottom
return ti
end
28 changes: 23 additions & 5 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3999,15 +3999,33 @@ end
@test (a, c)
@test (b, c)

@test @eval Module() begin
const ginit = Base.ImmutableDict{Any,Any}()
Base.return_types() do
g = ginit
init = Base.ImmutableDict{Number,Number}()
a = Const(init)
b = Core.PartialStruct(typeof(init), Any[Const(init), Any, ComplexF64])
c = Core.Compiler.tmerge(a, b)
@test (a, c) && (b, c)
@test c === typeof(init)

a = Core.PartialStruct(typeof(init), Any[Const(init), ComplexF64, ComplexF64])
c = Core.Compiler.tmerge(a, b)
@test (a, c) && (b, c)
@test c.fields[2] === Any # or Number
@test c.fields[3] === ComplexF64

b = Core.PartialStruct(typeof(init), Any[Const(init), ComplexF32, Union{ComplexF32,ComplexF64}])
c = Core.Compiler.tmerge(a, b)
@test (a, c)
@test (b, c)
@test c.fields[2] === Complex
@test c.fields[3] === Complex

global const ginit43784 = Base.ImmutableDict{Any,Any}()
@test Base.return_types() do
g = ginit43784
while true
g = Base.ImmutableDict(g, 1=>2)
end
end |> only === Union{}
end
end

# Test that purity modeling doesn't accidentally introduce new world age issues
Expand Down

0 comments on commit ff88fa4

Please sign in to comment.