From 7a4b120df620d6f403f9be77c55692759ec1bd58 Mon Sep 17 00:00:00 2001 From: Jameson Nash Date: Tue, 1 Mar 2022 22:44:52 -0500 Subject: [PATCH] inference: refine PartialStruct lattice tmerge Be more aggressive about merging fields to greatly accelerate convergence, but also compute anyrefine more correctly as we do now elsewhere (since #42831, a121721f975fc4105ed24ebd0ad1020d08d07a38) Move the tmeet algorithm, without changes, since it is a precise lattice operation, not a heuristic limit like tmerge. Close #43784 --- base/compiler/typelattice.jl | 42 +++++++++++++++++++ base/compiler/typelimits.jl | 78 ++++++++++++++---------------------- test/compiler/inference.jl | 28 ++++++++++--- 3 files changed, 95 insertions(+), 53 deletions(-) diff --git a/base/compiler/typelattice.jl b/base/compiler/typelattice.jl index fa669d9dade1b..21b51eb379f02 100644 --- a/base/compiler/typelattice.jl +++ b/base/compiler/typelattice.jl @@ -425,3 +425,45 @@ function stupdate1!(state::VarTable, change::StateUpdate) end return false end + +# compute typeintersect over the extended inference lattice, +# as precisely as we can, +# where v is in the extended lattice, and t is a Type. +function tmeet(@nospecialize(v), @nospecialize(t)) + if isa(v, Const) + if !has_free_typevars(t) && !isa(v.val, t) + return Bottom + end + return v + elseif isa(v, PartialStruct) + has_free_typevars(t) && return v + widev = widenconst(v) + if widev <: t + return v + end + ti = typeintersect(widev, t) + valid_as_lattice(ti) || return Bottom + @assert widev <: Tuple + new_fields = Vector{Any}(undef, length(v.fields)) + for i = 1:length(new_fields) + vfi = v.fields[i] + if isvarargtype(vfi) + new_fields[i] = vfi + else + new_fields[i] = tmeet(vfi, widenconst(getfield_tfunc(t, Const(i)))) + if new_fields[i] === Bottom + return Bottom + end + end + end + return tuple_tfunc(new_fields) + elseif isa(v, Conditional) + if !(Bool <: t) + return Bottom + end + return v + end + ti = typeintersect(widenconst(v), t) + valid_as_lattice(ti) || return Bottom + return ti +end diff --git a/base/compiler/typelimits.jl b/base/compiler/typelimits.jl index f5846732e9fc8..d08d7a80491c5 100644 --- a/base/compiler/typelimits.jl +++ b/base/compiler/typelimits.jl @@ -449,6 +449,7 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb)) aty = widenconst(typea) bty = widenconst(typeb) if aty === bty + # must have egal here, since we do not create PartialStruct for non-concrete types typea_nfields = nfields_tfunc(typea) typeb_nfields = nfields_tfunc(typeb) isa(typea_nfields, Const) || return aty @@ -457,18 +458,40 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb)) type_nfields === typeb_nfields.val::Int || return aty type_nfields == 0 && return aty fields = Vector{Any}(undef, type_nfields) - anyconst = false + anyrefine = false for i = 1:type_nfields ai = getfield_tfunc(typea, Const(i)) bi = getfield_tfunc(typeb, Const(i)) - ity = tmerge(ai, bi) - if ai === Union{} || bi === Union{} - ity = widenconst(ity) + ft = fieldtype(aty, i) + if is_lattice_equal(ai, bi) || is_lattice_equal(ai, ft) + # Since ai===bi, the given type has no restrictions on complexity. + # and can be used to refine ft + tyi = ai + elseif is_lattice_equal(bi, ft) + tyi = bi + else + # Otherwise choose between using the fieldtype or some other simple merged type. + # The wrapper type never has restrictions on complexity, + # so try to use that to refine the estimated type too. + tni = _typename(widenconst(ai)) + if tni isa Const && tni === _typename(widenconst(bi)) + # A tmeet call may cause tyi to become complex, but since the inputs were + # strictly limited to being egal, this has no restrictions on complexity. + # (Otherwise, we would need to use <: and take the narrower one without + # intersection. See the similar comment in abstract_call_method.) + tyi = typeintersect(ft, (tni.val::Core.TypeName).wrapper) + else + # Since aty===bty, the fieldtype has no restrictions on complexity. + tyi = ft + end + end + fields[i] = tyi + if !anyrefine + anyrefine = has_nontrivial_const_info(tyi) || # constant information + tyi ⋤ ft # just a type-level information, but more precise than the declared type end - fields[i] = ity - anyconst |= has_nontrivial_const_info(ity) end - return anyconst ? PartialStruct(aty, fields) : aty + return anyrefine ? PartialStruct(aty, fields) : aty end end if isa(typea, PartialOpaque) && isa(typeb, PartialOpaque) && widenconst(typea) == widenconst(typeb) @@ -655,44 +678,3 @@ function tuplemerge(a::DataType, b::DataType) end return Tuple{p...} end - -# compute typeintersect over the extended inference lattice -# where v is in the extended lattice, and t is a Type -function tmeet(@nospecialize(v), @nospecialize(t)) - if isa(v, Const) - if !has_free_typevars(t) && !isa(v.val, t) - return Bottom - end - return v - elseif isa(v, PartialStruct) - has_free_typevars(t) && return v - widev = widenconst(v) - if widev <: t - return v - end - ti = typeintersect(widev, t) - valid_as_lattice(ti) || return Bottom - @assert widev <: Tuple - new_fields = Vector{Any}(undef, length(v.fields)) - for i = 1:length(new_fields) - vfi = v.fields[i] - if isvarargtype(vfi) - new_fields[i] = vfi - else - new_fields[i] = tmeet(vfi, widenconst(getfield_tfunc(t, Const(i)))) - if new_fields[i] === Bottom - return Bottom - end - end - end - return tuple_tfunc(new_fields) - elseif isa(v, Conditional) - if !(Bool <: t) - return Bottom - end - return v - end - ti = typeintersect(widenconst(v), t) - valid_as_lattice(ti) || return Bottom - return ti -end diff --git a/test/compiler/inference.jl b/test/compiler/inference.jl index 61058f9589f52..218e484b2beca 100644 --- a/test/compiler/inference.jl +++ b/test/compiler/inference.jl @@ -3999,15 +3999,33 @@ end @test ⊑(a, c) @test ⊑(b, c) - @test @eval Module() begin - const ginit = Base.ImmutableDict{Any,Any}() - Base.return_types() do - g = ginit + init = Base.ImmutableDict{Number,Number}() + a = Const(init) + b = Core.PartialStruct(typeof(init), Any[Const(init), Any, ComplexF64]) + c = Core.Compiler.tmerge(a, b) + @test ⊑(a, c) && ⊑(b, c) + @test c === typeof(init) + + a = Core.PartialStruct(typeof(init), Any[Const(init), ComplexF64, ComplexF64]) + c = Core.Compiler.tmerge(a, b) + @test ⊑(a, c) && ⊑(b, c) + @test c.fields[2] === Any # or Number + @test c.fields[3] === ComplexF64 + + b = Core.PartialStruct(typeof(init), Any[Const(init), ComplexF32, Union{ComplexF32,ComplexF64}]) + c = Core.Compiler.tmerge(a, b) + @test ⊑(a, c) + @test ⊑(b, c) + @test c.fields[2] === Complex + @test c.fields[3] === Complex + + global const ginit43784 = Base.ImmutableDict{Any,Any}() + @test Base.return_types() do + g = ginit43784 while true g = Base.ImmutableDict(g, 1=>2) end end |> only === Union{} - end end # Test that purity modeling doesn't accidentally introduce new world age issues