Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

inference: refine branched Conditional types #55216

Merged
merged 1 commit into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 40 additions & 20 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3067,7 +3067,7 @@ end
@inline function abstract_eval_basic_statement(interp::AbstractInterpreter,
@nospecialize(stmt), pc_vartable::VarTable, frame::InferenceState)
if isa(stmt, NewvarNode)
changes = StateUpdate(stmt.slot, VarState(Bottom, true), pc_vartable, false)
changes = StateUpdate(stmt.slot, VarState(Bottom, true), false)
return BasicStmtChange(changes, nothing, Union{})
elseif !isa(stmt, Expr)
(; rt, exct) = abstract_eval_statement(interp, stmt, pc_vartable, frame)
Expand All @@ -3082,7 +3082,7 @@ end
end
lhs = stmt.args[1]
if isa(lhs, SlotNumber)
changes = StateUpdate(lhs, VarState(rt, false), pc_vartable, false)
changes = StateUpdate(lhs, VarState(rt, false), false)
elseif isa(lhs, GlobalRef)
handle_global_assignment!(interp, frame, lhs, rt)
elseif !isa(lhs, SSAValue)
Expand All @@ -3092,7 +3092,7 @@ end
elseif hd === :method
fname = stmt.args[1]
if isa(fname, SlotNumber)
changes = StateUpdate(fname, VarState(Any, false), pc_vartable, false)
changes = StateUpdate(fname, VarState(Any, false), false)
end
return BasicStmtChange(changes, nothing, Union{})
elseif (hd === :code_coverage_effect || (
Expand Down Expand Up @@ -3242,18 +3242,18 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
@goto branch
elseif isa(stmt, GotoIfNot)
condx = stmt.cond
condxslot = ssa_def_slot(condx, frame)
condslot = ssa_def_slot(condx, frame)
condt = abstract_eval_value(interp, condx, currstate, frame)
if condt === Bottom
ssavaluetypes[currpc] = Bottom
empty!(frame.pclimitations)
@goto find_next_bb
end
orig_condt = condt
if !(isa(condt, Const) || isa(condt, Conditional)) && isa(condxslot, SlotNumber)
if !(isa(condt, Const) || isa(condt, Conditional)) && isa(condslot, SlotNumber)
# if this non-`Conditional` object is a slot, we form and propagate
# the conditional constraint on it
condt = Conditional(condxslot, Const(true), Const(false))
condt = Conditional(condslot, Const(true), Const(false))
end
condval = maybe_extract_const_bool(condt)
nothrow = (condval !== nothing) || ⊑(𝕃ᵢ, orig_condt, Bool)
Expand Down Expand Up @@ -3299,21 +3299,31 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
# We continue with the true branch, but process the false
# branch here.
if isa(condt, Conditional)
else_change = conditional_change(𝕃ᵢ, currstate, condt.elsetype, condt.slot)
else_change = conditional_change(𝕃ᵢ, currstate, condt, #=then_or_else=#false)
if else_change !== nothing
false_vartable = stoverwrite1!(copy(currstate), else_change)
elsestate = copy(currstate)
stoverwrite1!(elsestate, else_change)
elseif condslot isa SlotNumber
elsestate = copy(currstate)
else
false_vartable = currstate
elsestate = currstate
end
changed = update_bbstate!(𝕃ᵢ, frame, falsebb, false_vartable)
then_change = conditional_change(𝕃ᵢ, currstate, condt.thentype, condt.slot)
if condslot isa SlotNumber # refine the type of this conditional object itself for this else branch
stoverwrite1!(elsestate, condition_object_change(currstate, condt, condslot, #=then_or_else=#false))
end
else_changed = update_bbstate!(𝕃ᵢ, frame, falsebb, elsestate)
then_change = conditional_change(𝕃ᵢ, currstate, condt, #=then_or_else=#true)
thenstate = currstate
if then_change !== nothing
stoverwrite1!(currstate, then_change)
stoverwrite1!(thenstate, then_change)
end
if condslot isa SlotNumber # refine the type of this conditional object itself for this then branch
stoverwrite1!(thenstate, condition_object_change(currstate, condt, condslot, #=then_or_else=#true))
end
else
changed = update_bbstate!(𝕃ᵢ, frame, falsebb, currstate)
else_changed = update_bbstate!(𝕃ᵢ, frame, falsebb, currstate)
end
if changed
if else_changed
handle_control_backedge!(interp, frame, currpc, stmt.dest)
push!(W, falsebb)
end
Expand Down Expand Up @@ -3412,13 +3422,14 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
nothing
end

function conditional_change(𝕃ᵢ::AbstractLattice, state::VarTable, @nospecialize(typ), slot::Int)
vtype = state[slot]
function conditional_change(𝕃ᵢ::AbstractLattice, currstate::VarTable, condt::Conditional, then_or_else::Bool)
vtype = currstate[condt.slot]
oldtyp = vtype.typ
if iskindtype(typ)
newtyp = then_or_else ? condt.thentype : condt.elsetype
if iskindtype(newtyp)
# this code path corresponds to the special handling for `isa(x, iskindtype)` check
# implemented within `abstract_call_builtin`
elseif ⊑(𝕃ᵢ, ignorelimited(typ), ignorelimited(oldtyp))
elseif ⊑(𝕃ᵢ, ignorelimited(newtyp), ignorelimited(oldtyp))
# approximate test for `typ ∩ oldtyp` being better than `oldtyp`
# since we probably formed these types with `typesubstract`,
# the comparison is likely simple
Expand All @@ -3428,9 +3439,18 @@ function conditional_change(𝕃ᵢ::AbstractLattice, state::VarTable, @nospecia
if oldtyp isa LimitedAccuracy
# typ is better unlimited, but we may still need to compute the tmeet with the limit
# "causes" since we ignored those in the comparison
typ = tmerge(𝕃ᵢ, typ, LimitedAccuracy(Bottom, oldtyp.causes))
newtyp = tmerge(𝕃ᵢ, newtyp, LimitedAccuracy(Bottom, oldtyp.causes))
end
return StateUpdate(SlotNumber(slot), VarState(typ, vtype.undef), state, true)
return StateUpdate(SlotNumber(condt.slot), VarState(newtyp, vtype.undef), true)
end

function condition_object_change(currstate::VarTable, condt::Conditional,
condslot::SlotNumber, then_or_else::Bool)
vtype = currstate[slot_id(condslot)]
newcondt = Conditional(condt.slot,
then_or_else ? condt.thentype : Union{},
then_or_else ? Union{} : condt.elsetype)
return StateUpdate(condslot, VarState(newcondt, vtype.undef), false)
end

# make as much progress on `frame` as possible (by handling cycles)
Expand Down
23 changes: 0 additions & 23 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,6 @@ end
struct StateUpdate
var::SlotNumber
vtype::VarState
state::VarTable
conditional::Bool
end

Expand Down Expand Up @@ -724,28 +723,6 @@ function invalidate_slotwrapper(vt::VarState, changeid::Int, ignore_conditional:
return nothing
end

function stupdate!(lattice::AbstractLattice, state::VarTable, changes::StateUpdate)
changed = false
changeid = slot_id(changes.var)
for i = 1:length(state)
if i == changeid
newtype = changes.vtype
else
newtype = changes.state[i]
end
invalidated = invalidate_slotwrapper(newtype, changeid, changes.conditional)
if invalidated !== nothing
newtype = invalidated
end
oldtype = state[i]
if schanged(lattice, newtype, oldtype)
state[i] = smerge(lattice, oldtype, newtype)
changed = true
end
end
return changed
end

function stupdate!(lattice::AbstractLattice, state::VarTable, changes::VarTable)
changed = false
for i = 1:length(state)
Expand Down
87 changes: 42 additions & 45 deletions test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2151,78 +2151,75 @@ end

@testset "branching on conditional object" begin
# simple
@test Base.return_types((Union{Nothing,Int},)) do a
@test Base.infer_return_type((Union{Nothing,Int},)) do a
b = a === nothing
return b ? 0 : a # ::Int
end == Any[Int]
end == Int

# can use multiple times (as far as the subject of condition hasn't changed)
@test Base.return_types((Union{Nothing,Int},)) do a
@test Base.infer_return_type((Union{Nothing,Int},)) do a
b = a === nothing
c = b ? 0 : a # c::Int
d = !b ? a : 0 # d::Int
return c, d # ::Tuple{Int,Int}
end == Any[Tuple{Int,Int}]
end == Tuple{Int,Int}

# should invalidate old constraint when the subject of condition has changed
@test Base.return_types((Union{Nothing,Int},)) do a
@test Base.infer_return_type((Union{Nothing,Int},)) do a
cond = a === nothing
r1 = cond ? 0 : a # r1::Int
a = 0
r2 = cond ? a : 1 # r2::Int, not r2::Union{Nothing,Int}
return r1, r2 # ::Tuple{Int,Int}
end == Any[Tuple{Int,Int}]
end == Tuple{Int,Int}
end

# https://github.com/JuliaLang/julia/issues/42090#issuecomment-911824851
# `PartialStruct` shouldn't wrap `Conditional`
let M = Module()
@eval M begin
struct BePartialStruct
val::Int
cond
end
end

rt = @eval M begin
Base.return_types((Union{Nothing,Int},)) do a
cond = a === nothing
obj = $(Expr(:new, M.BePartialStruct, 42, :cond))
r1 = getfield(obj, :cond) ? 0 : a # r1::Union{Nothing,Int}, not r1::Int (because PartialStruct doesn't wrap Conditional)
a = $(gensym(:anyvar))::Any
r2 = getfield(obj, :cond) ? a : nothing # r2::Any, not r2::Const(nothing) (we don't need to worry about constraint invalidation here)
return r1, r2 # ::Tuple{Union{Nothing,Int},Any}
end |> only
end
@test rt == Tuple{Union{Nothing,Int},Any}
struct BePartialStruct
val::Int
cond
end
@test Tuple{Union{Nothing,Int},Any} == @eval Base.infer_return_type((Union{Nothing,Int},)) do a
cond = a === nothing
obj = $(Expr(:new, BePartialStruct, 42, :cond))
r1 = getfield(obj, :cond) ? 0 : a # r1::Union{Nothing,Int}, not r1::Int (because PartialStruct doesn't wrap Conditional)
a = $(gensym(:anyvar))::Any
r2 = getfield(obj, :cond) ? a : nothing # r2::Any, not r2::Const(nothing) (we don't need to worry about constraint invalidation here)
return r1, r2 # ::Tuple{Union{Nothing,Int},Any}
end

# make sure we never form nested `Conditional` (https://github.com/JuliaLang/julia/issues/46207)
@test Base.return_types((Any,)) do a
@test Base.infer_return_type((Any,)) do a
c = isa(a, Integer)
42 === c ? :a : "b"
end |> only === String
@test Base.return_types((Any,)) do a
end == String
@test Base.infer_return_type((Any,)) do a
c = isa(a, Integer)
c === 42 ? :a : "b"
end |> only === String
end == String

@testset "conditional constraint propagation from non-`Conditional` object" begin
@test Base.return_types((Bool,)) do b
if b
return !b ? nothing : 1 # ::Int
else
return 0
end
end == Any[Int]

@test Base.return_types((Any,)) do b
if b
return b # ::Bool
else
return nothing
end
end == Any[Union{Bool,Nothing}]
function condition_object_update1(cond)
if cond # `cond` is known to be `Const(true)` within this branch
return !cond ? nothing : 1 # ::Int
else
return cond ? nothing : 1 # ::Int
end
end
function condition_object_update2(x)
cond = x isa Int
if cond # `cond` is known to be `Const(true)` within this branch
return !cond ? nothing : x # ::Int
else
return cond ? nothing : 1 # ::Int
end
end
@testset "state update for condition object" begin
# refine the type of condition object into constant boolean values on branching
@test Base.infer_return_type(condition_object_update1, (Bool,)) == Int
@test Base.infer_return_type(condition_object_update1, (Any,)) == Int
# refine even when their original type is `Conditional`
@test Base.infer_return_type(condition_object_update2, (Any,)) == Int
end

@testset "`from_interprocedural!`: translate inter-procedural information" begin
Expand Down