Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inference: refine PartialStruct lattice tmerge #44404

Merged
merged 2 commits into from
Mar 13, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 85 additions & 1 deletion base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ The non-strict partial order over the type inference lattice.
end
for i in 1:nfields(a.val)
# XXX: let's handle varargs later
isdefined(a.val, i) || return false
isdefined(a.val, i) || continue # since ∀ T Union{} ⊑ T
⊑(Const(getfield(a.val, i)), b.fields[i]) || return false
end
return true
Expand Down Expand Up @@ -289,6 +289,48 @@ function is_lattice_equal(@nospecialize(a), @nospecialize(b))
return a ⊑ b && b ⊑ a
end

# compute typeintersect over the extended inference lattice,
# as precisely as we can,
# where v is in the extended lattice, and t is a Type.
function tmeet(@nospecialize(v), @nospecialize(t))
if isa(v, Const)
if !has_free_typevars(t) && !isa(v.val, t)
return Bottom
end
return v
elseif isa(v, PartialStruct)
has_free_typevars(t) && return v
widev = widenconst(v)
if widev <: t
return v
end
ti = typeintersect(widev, t)
valid_as_lattice(ti) || return Bottom
@assert widev <: Tuple
new_fields = Vector{Any}(undef, length(v.fields))
for i = 1:length(new_fields)
vfi = v.fields[i]
if isvarargtype(vfi)
new_fields[i] = vfi
else
new_fields[i] = tmeet(vfi, widenconst(getfield_tfunc(t, Const(i))))
if new_fields[i] === Bottom
return Bottom
end
end
end
return tuple_tfunc(new_fields)
elseif isa(v, Conditional)
if !(Bool <: t)
return Bottom
end
return v
end
ti = typeintersect(widenconst(v), t)
valid_as_lattice(ti) || return Bottom
return ti
end

widenconst(c::AnyConditional) = Bool
widenconst((; val)::Const) = isa(val, Type) ? Type{val} : typeof(val)
widenconst(m::MaybeUndef) = widenconst(m.typ)
Expand Down Expand Up @@ -427,3 +469,45 @@ function stupdate1!(state::VarTable, change::StateUpdate)
end
return false
end

# compute typeintersect over the extended inference lattice,
# as precisely as we can,
# where v is in the extended lattice, and t is a Type.
vtjnash marked this conversation as resolved.
Show resolved Hide resolved
function tmeet(@nospecialize(v), @nospecialize(t))
Comment on lines +473 to +476
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Duplicated with the definition at L295.

if isa(v, Const)
if !has_free_typevars(t) && !isa(v.val, t)
return Bottom
end
return v
elseif isa(v, PartialStruct)
has_free_typevars(t) && return v
widev = widenconst(v)
if widev <: t
return v
end
ti = typeintersect(widev, t)
valid_as_lattice(ti) || return Bottom
@assert widev <: Tuple
new_fields = Vector{Any}(undef, length(v.fields))
for i = 1:length(new_fields)
vfi = v.fields[i]
if isvarargtype(vfi)
new_fields[i] = vfi
else
new_fields[i] = tmeet(vfi, widenconst(getfield_tfunc(t, Const(i))))
if new_fields[i] === Bottom
return Bottom
end
end
end
return tuple_tfunc(new_fields)
elseif isa(v, Conditional)
if !(Bool <: t)
return Bottom
end
return v
end
ti = typeintersect(widenconst(v), t)
valid_as_lattice(ti) || return Bottom
return ti
end
132 changes: 81 additions & 51 deletions base/compiler/typelimits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,23 +298,71 @@ union_count_abstract(x::Union) = union_count_abstract(x.a) + union_count_abstrac
union_count_abstract(@nospecialize(x)) = !isdispatchelem(x)

function issimpleenoughtype(@nospecialize t)
t = ignorelimited(t)
return unionlen(t) + union_count_abstract(t) <= MAX_TYPEUNION_LENGTH &&
unioncomplexity(t) <= MAX_TYPEUNION_COMPLEXITY
end

# A simplified type_more_complex query over the extended lattice
# (assumes typeb ⊑ typea)
function issimplertype(@nospecialize(typea), @nospecialize(typeb))
typea = ignorelimited(typea)
typeb = ignorelimited(typeb)
typea isa MaybeUndef && (typea = typea.typ) # n.b. does not appear in inference
typeb isa MaybeUndef && (typeb = typeb.typ) # n.b. does not appear in inference
typea === typeb && return true
if typea isa PartialStruct
aty = widenconst(typea)
for i = 1:length(typea.fields)
ai = typea.fields[i]
bi = fieldtype(aty, i)
is_lattice_equal(ai, bi) && continue
tni = _typename(widenconst(ai))
if tni isa Const
bi = (tni.val::Core.TypeName).wrapper
is_lattice_equal(ai, bi) && continue
end
bi = getfield_tfunc(typeb, Const(i))
is_lattice_equal(ai, bi) && continue
# It is not enough for ai to be simpler than bi: it must exactly equal
# (for this, an invariant struct field, by contrast to
# type_more_complex above which handles covariant tuples).
return false
end
elseif typea isa Type
return issimpleenoughtype(typea)
# elseif typea isa Const # fall-through good
elseif typea isa Conditional # follow issubconditional query
typeb isa Const && return true
typeb isa Conditional || return false
is_same_conditionals(typea, typeb) || return false
issimplertype(typea.vtype, typeb.vtype) || return false
vtjnash marked this conversation as resolved.
Show resolved Hide resolved
issimplertype(typea.elsetype, typeb.elsetype) || return false
elseif typea isa InterConditional # ibid
typeb isa Const && return true
typeb isa InterConditional || return false
is_same_conditionals(typea, typeb) || return false
issimplertype(typea.vtype, typeb.vtype) || return false
issimplertype(typea.elsetype, typeb.elsetype) || return false
elseif typea isa PartialOpaque
# TODO
end
return true
end

# pick a wider type that contains both typea and typeb,
# with some limits on how "large" it can get,
# but without losing too much precision in common cases
# and also trying to be mostly associative and commutative
function tmerge(@nospecialize(typea), @nospecialize(typeb))
typea === Union{} && return typeb
typeb === Union{} && return typea
typea === typeb && return typea

suba = typea ⊑ typeb
suba && issimpleenoughtype(typeb) && return typeb
suba && issimplertype(typeb, typea) && return typeb
subb = typeb ⊑ typea
suba && subb && return typea
subb && issimpleenoughtype(typea) && return typea
subb && issimplertype(typea, typeb) && return typea

# type-lattice for LimitedAccuracy wrapper
# the merge create a slightly narrower type than needed, but we can't
Expand Down Expand Up @@ -404,6 +452,7 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb))
aty = widenconst(typea)
bty = widenconst(typeb)
if aty === bty
# must have egal here, since we do not create PartialStruct for non-concrete types
typea_nfields = nfields_tfunc(typea)
typeb_nfields = nfields_tfunc(typeb)
isa(typea_nfields, Const) || return aty
Expand All @@ -412,18 +461,40 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb))
type_nfields === typeb_nfields.val::Int || return aty
type_nfields == 0 && return aty
fields = Vector{Any}(undef, type_nfields)
anyconst = false
anyrefine = false
for i = 1:type_nfields
ai = getfield_tfunc(typea, Const(i))
bi = getfield_tfunc(typeb, Const(i))
ity = tmerge(ai, bi)
if ai === Union{} || bi === Union{}
ity = widenconst(ity)
ft = fieldtype(aty, i)
if is_lattice_equal(ai, bi) || is_lattice_equal(ai, ft)
# Since ai===bi, the given type has no restrictions on complexity.
# and can be used to refine ft
tyi = ai
elseif is_lattice_equal(bi, ft)
tyi = bi
else
# Otherwise choose between using the fieldtype or some other simple merged type.
# The wrapper type never has restrictions on complexity,
# so try to use that to refine the estimated type too.
tni = _typename(widenconst(ai))
if tni isa Const && tni === _typename(widenconst(bi))
# A tmeet call may cause tyi to become complex, but since the inputs were
# strictly limited to being egal, this has no restrictions on complexity.
# (Otherwise, we would need to use <: and take the narrower one without
# intersection. See the similar comment in abstract_call_method.)
tyi = typeintersect(ft, (tni.val::Core.TypeName).wrapper)
else
# Since aty===bty, the fieldtype has no restrictions on complexity.
tyi = ft
end
end
fields[i] = tyi
if !anyrefine
anyrefine = has_nontrivial_const_info(tyi) || # constant information
tyi ⋤ ft # just a type-level information, but more precise than the declared type
vtjnash marked this conversation as resolved.
Show resolved Hide resolved
end
fields[i] = ity
anyconst |= has_nontrivial_const_info(ity)
end
return anyconst ? PartialStruct(aty, fields) : aty
return anyrefine ? PartialStruct(aty, fields) : aty
end
end
if isa(typea, PartialOpaque) && isa(typeb, PartialOpaque) && widenconst(typea) == widenconst(typeb)
Expand Down Expand Up @@ -610,44 +681,3 @@ function tuplemerge(a::DataType, b::DataType)
end
return Tuple{p...}
end

# compute typeintersect over the extended inference lattice
# where v is in the extended lattice, and t is a Type
function tmeet(@nospecialize(v), @nospecialize(t))
if isa(v, Const)
if !has_free_typevars(t) && !isa(v.val, t)
return Bottom
end
return v
elseif isa(v, PartialStruct)
has_free_typevars(t) && return v
widev = widenconst(v)
if widev <: t
return v
end
ti = typeintersect(widev, t)
valid_as_lattice(ti) || return Bottom
@assert widev <: Tuple
new_fields = Vector{Any}(undef, length(v.fields))
for i = 1:length(new_fields)
vfi = v.fields[i]
if isvarargtype(vfi)
new_fields[i] = vfi
else
new_fields[i] = tmeet(vfi, widenconst(getfield_tfunc(t, Const(i))))
if new_fields[i] === Bottom
return Bottom
end
end
end
return tuple_tfunc(new_fields)
elseif isa(v, Conditional)
if !(Bool <: t)
return Bottom
end
return v
end
ti = typeintersect(widenconst(v), t)
valid_as_lattice(ti) || return Bottom
return ti
end
28 changes: 23 additions & 5 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3999,15 +3999,33 @@ end
@test ⊑(a, c)
@test ⊑(b, c)

@test @eval Module() begin
const ginit = Base.ImmutableDict{Any,Any}()
Base.return_types() do
g = ginit
init = Base.ImmutableDict{Number,Number}()
a = Const(init)
b = Core.PartialStruct(typeof(init), Any[Const(init), Any, ComplexF64])
c = Core.Compiler.tmerge(a, b)
@test ⊑(a, c) && ⊑(b, c)
vtjnash marked this conversation as resolved.
Show resolved Hide resolved
@test c === typeof(init)

a = Core.PartialStruct(typeof(init), Any[Const(init), ComplexF64, ComplexF64])
c = Core.Compiler.tmerge(a, b)
@test ⊑(a, c) && ⊑(b, c)
@test c.fields[2] === Any # or Number
vtjnash marked this conversation as resolved.
Show resolved Hide resolved
@test c.fields[3] === ComplexF64

b = Core.PartialStruct(typeof(init), Any[Const(init), ComplexF32, Union{ComplexF32,ComplexF64}])
c = Core.Compiler.tmerge(a, b)
vtjnash marked this conversation as resolved.
Show resolved Hide resolved
@test ⊑(a, c)
@test ⊑(b, c)
@test c.fields[2] === Complex
@test c.fields[3] === Complex

global const ginit43784 = Base.ImmutableDict{Any,Any}()
@test Base.return_types() do
g = ginit43784
while true
g = Base.ImmutableDict(g, 1=>2)
end
end |> only === Union{}
end
end

# Test that purity modeling doesn't accidentally introduce new world age issues
Expand Down