Skip to content

Commit

Permalink
inference: more stronger state update from branching
Browse files Browse the repository at this point in the history
This change addresses 
JuliaLang#40832 (review)
  • Loading branch information
aviatesk committed May 22, 2021
1 parent 34498f2 commit 85eedcb
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 27 deletions.
59 changes: 33 additions & 26 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1395,7 +1395,7 @@ function abstract_eval_special_value(interp::AbstractInterpreter, @nospecialize(
elseif isa(e, SSAValue)
return abstract_eval_ssavalue(e::SSAValue, sv.src)
elseif isa(e, SlotNumber) || isa(e, Argument)
return (vtypes[slot_id(e)]::VarState).typ
return get_varstate(vtypes, slot_id(e)).typ
elseif isa(e, GlobalRef)
return abstract_eval_global(e.mod, e.name)
end
Expand Down Expand Up @@ -1713,11 +1713,6 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
empty!(frame.pclimitations)
break
end
if !(isa(condt, Const) || isa(condt, Conditional)) && isa(condx, SlotNumber)
# if this non-`Conditional` object is a slot, we form and propagate
# the conditional constraint on it
condt = Conditional(condx, Const(true), Const(false))
end
condval = maybe_extract_const_bool(condt)
l = stmt.dest::Int
if !isempty(frame.pclimitations)
Expand All @@ -1739,6 +1734,12 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
changes_else = conditional_changes(changes_else, condt.elsetype, condt.var)
changes = conditional_changes(changes, condt.vtype, condt.var)
end
if isa(condx, SlotNumber)
tfalse = isa(condt, Conditional) ? Conditional(condt.var, Bottom, condt.elsetype) : Const(false)
ttrue = isa(condt, Conditional) ? Conditional(condt.var, condt.vtype, Bottom) : Const(true)
changes_else = add_state_change!(changes_else, condx, tfalse, true)
changes = add_state_change!(changes, condx, ttrue, true)
end
newstate_else = stupdate!(states[l], changes_else)
if newstate_else !== nothing
# add else branch to active IP list
Expand Down Expand Up @@ -1861,7 +1862,7 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)

pc´ > n && break # can't proceed with the fast-path fall-through
frame.handler_at[pc´] = frame.cur_hand
changes = collect_state_updates!(changes, frame)
changes = collect_state_changes!(changes, frame)
newstate = stupdate!(states[pc´], changes)
if isa(stmt, GotoNode) && frame.pc´´ < pc´
# if we are processing a goto node anyways,
Expand All @@ -1888,41 +1889,47 @@ function typeinf_local(interp::AbstractInterpreter, frame::InferenceState)
end

function add_state_update!(slot::SlotNumber, @nospecialize(new), frame::InferenceState)
states = frame.stmt_types[frame.currpc]::VarTable
old = ((states[slot_id(slot)])::VarState).typ
state = frame.stmt_types[frame.currpc]::VarTable
old = get_varstate(state, slot).typ
if !(old new) # new ⋤ old
push!(frame.state_updates, (slot, new))
return true
end
return false
end

function collect_state_updates!(changes::StateUpdate, frame::InferenceState)
state_updates = frame.state_updates
function collect_state_changes!(changes::StateUpdate, frame::InferenceState, conditional::Bool = changes.conditional)
slots = BitSet(slot_id(var) for (var, _) in changes.updates)
while !isempty(state_updates)
var, typ = pop!(state_updates)
while !isempty(frame.state_updates)
var, typ = pop!(frame.state_updates)
slot_id(var) in slots && continue # effects of the statement (like assignment) should have the precedence
vtype = VarState(typ, (changes.state[slot_id(var)]::VarState).undef)
push!(changes.updates, (var, vtype))
changes = add_state_change!(changes, var, typ, conditional)
end
return changes
end

function collect_state_updates!(changes::VarTable, frame::InferenceState)
state_updates = frame.state_updates
isempty(state_updates) && return changes
updates = Tuple{SlotNumber,VarState}[]
while !isempty(state_updates)
var, typ = pop!(state_updates)
vtype = VarState(typ, (changes[slot_id(var)]::VarState).undef)
push!(updates, (var, vtype))
function collect_state_changes!(changes::VarTable, frame::InferenceState, conditional::Bool = false)
while !isempty(frame.state_updates)
var, typ = pop!(frame.state_updates)
changes = add_state_change!(changes, var, typ, conditional)
end
return StateUpdate(updates, changes, false)
return changes
end

function add_state_change!(changes::StateUpdate, var::SlotNumber, @nospecialize(typ), conditional::Bool)
@assert changes.conditional === conditional
vtype = VarState(typ, get_varstate(changes.state, var).undef)
push!(changes.updates, (var, vtype))
return changes
end

function add_state_change!(changes::VarTable, var::SlotNumber, @nospecialize(typ), conditional::Bool)
vtype = VarState(typ, get_varstate(changes, var).undef)
return StateUpdate([(var, vtype)], changes, conditional)
end

function conditional_changes(changes::VarTable, @nospecialize(typ), var::SlotNumber)
oldtyp = (changes[slot_id(var)]::VarState).typ
oldtyp = get_varstate(changes, var).typ
# approximate test for `typ ∩ oldtyp` being better than `oldtyp`
# since we probably formed these types with `typesubstract`, the comparison is likely simple
if ignorelimited(typ) ignorelimited(oldtyp)
Expand All @@ -1935,7 +1942,7 @@ end

function bool_rt_to_conditional(@nospecialize(rt), slottypes::Vector{Any}, state::VarTable, slot_id::Int)
old = slottypes[slot_id]
new = widenconditional((state[slot_id]::VarState).typ) # avoid nested conditional
new = widenconditional(get_varstate(state, slot_id).typ)
if new old && !(old new)
if isa(rt, Const)
val = rt.val
Expand Down
3 changes: 3 additions & 0 deletions base/compiler/typelattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ end

const VarTable = Array{Any,1}

get_varstate(state::VarTable, slot::SlotNumber) = get_varstate(state, slot_id(slot))
get_varstate(state::VarTable, slot::Int) = state[slot]::VarState

struct StateUpdate
updates::Vector{Tuple{SlotNumber,VarState}}
state::VarTable
Expand Down
14 changes: 13 additions & 1 deletion test/compiler/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1833,7 +1833,8 @@ end
end == Any[Tuple{Int,Int}]
end

@testset "conditional constraint propagation from non-`Conditional` object" begin
@testset "state update on branching" begin
# refine condition type into constant boolean value on branching
@test Base.return_types((Bool,)) do b
if b
return !b ? nothing : 1 # ::Int
Expand All @@ -1842,13 +1843,24 @@ end
end
end == Any[Int]

# even when the original type isn't boolean type
@test Base.return_types((Any,)) do b
if b
return b # ::Bool
else
return nothing
end
end == Any[Union{Bool,Nothing}]

# and it still propagate `Conditional` information
@test Base.return_types((Any,)) do a
b = isa(a, Int)
if b
return !b ? nothing : a # ::Int
else
return 0
end
end == Any[Int]
end

function f25579(g)
Expand Down

0 comments on commit 85eedcb

Please sign in to comment.