Skip to content

Commit

Permalink
[NewOptimizer] Better handling in the presence of select value
Browse files Browse the repository at this point in the history
The benchmarks contain code like this:
```
x::Union{Nothing, Int}
result += ifelse(x === nothing, 0, x)
```
which, perhaps somewhat ironically is quite a bit harder
on the new optimizer than an equivalent code sequence
using ternary operators. The reason for this is that
ifelse gets inferred as `Union{Int, Nothing}`, creating
a phi node of that type, which then causes a union split +
that the optimizer can't really get rid of easily. What this
commit does is add some local improvements to help with the
situation. First, it adds some minimal back inference during
inlining. As a result, when inlining decides to unionsplit
`ifelse(x === nothing, 0, x::Union{Nothing, Int})`, it looks
back at the definition of `x === nothing`, realizes it's constrained
by the union split and inserts the appropriate boolean constant.
Next, a new `type_tightening_pass` goes back and annotates more precise
types for the inlinined `select_value` and phi nodes. This is sufficient
to get the above code to behave reasonably and should hopefully fix
the performance regression on the various union sum benchmarks
seen in #26795.
  • Loading branch information
Keno committed May 3, 2018
1 parent 61350d6 commit 51b8f17
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 47 deletions.
91 changes: 88 additions & 3 deletions base/compiler/ssair/inlining2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,61 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector
return_value
end

# Constraints are generally small, so a linear search is the bets option
function find_constraint(val, constraints)
for i = 1:length(constraints)
if val === constraints[i][1]
return constraints[i][2]
end
end
return nothing
end

# Performs minimal backwards inference to catch a couple of interesting, common cases
function minimal_backinf(compact, constraints, unconstrained_types, argexprs)
for i = 2:length(argexprs)
isa(argexprs[i], SSAValue) || continue
# Check if the argexpr is in the constraint list directly
c = find_constraint(argexprs[i], constraints)
if c !== nothing
unconstrained_types[i] = c
end
# For boolean values check for type predicates on any of the constraints
ut = unconstrained_types[i]
if ut === Bool
def = compact[argexprs[i]]
isa(def, Expr) || continue
if is_known_call(def, ===, compact)
v1, v2, = def.args[2:3]
c = find_constraint(v1, constraints)
if c !== nothing
refined = egal_tfunc(c, compact_exprtype(compact, v2))
if !(ut refined)
unconstrained_types[i] = refined
end
end
c = find_constraint(v2, constraints)
if c !== nothing
refined = egal_tfunc(compact_exprtype(compact, v1), c)
if !(ut refined)
unconstrained_types[i] = refined
end
end
elseif is_known_call(def, isa, compact)
v = def.args[2]
c = find_constraint(v, constraints)
if c !== nothing
refined = isa_tfunc(c, compact_exprtype(compact, def.args[3]))
if !(ut refined)
unconstrained_types[i] = refined
end
end
end
end
end
unconstrained_types
end

function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int,
argexprs::Vector{Any}, linetable::Vector{LineInfoNode},
item::UnionSplit, boundscheck::Symbol, todo_bbs::Vector{Tuple{Int, Int}})
Expand Down Expand Up @@ -379,11 +434,40 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int,
insert_node_here!(compact, GotoIfNot(cond, next_cond_bb), Union{}, line)
bb = next_cond_bb - 1
finish_current_bb!(compact)
# Insert Pi nodes here
if !isa(case, ConstantCase)
argexprs′ = copy(argexprs)
constraints = Pair{SSAValue, Any}[]
unconstrained_types = Any[atype.parameters...]
for i = 2:length(metharg.parameters)
a, m = unconstrained_types[i], metharg.parameters[i]
isa(argexprs[i], SSAValue) || continue
if !(a <: m)
push!(constraints, Pair{SSAValue, Any}(argexprs[i], m))
end
end
constrained_types = minimal_backinf(compact, constraints, unconstrained_types, argexprs)
for i = 2:length(metharg.parameters)
if !(atype.parameters[i] constrained_types[i])
if isa(constrained_types[i], Const)
argexprs′[i] = constrained_types[i].val
else
ct = widenconst(constrained_types[i])
if isa(ct, DataType) && isdefined(ct, :instance)
argexprs′[i] = ct.instance
else
argexprs′[i] = insert_node_here!(compact, PiNode(argexprs′[i], constrained_types[i]),
constrained_types[i], line)
end
end
end
end
else
argexprs′ = argexprs
end
if isa(case, InliningTodo)
val = ir_inline_item!(compact, idx, argexprs, linetable, case, boundscheck, todo_bbs)
val = ir_inline_item!(compact, idx, argexprs, linetable, case, boundscheck, todo_bbs)
elseif isa(case, MethodInstance)
val = insert_node_here!(compact, Expr(:invoke, case, argexprs...), typ, line)
val = insert_node_here!(compact, Expr(:invoke, case, argexprs...), typ, line)
else
case = case::ConstantCase
val = case.val
Expand Down Expand Up @@ -861,6 +945,7 @@ function assemble_inline_todo!(ir::IRCode, linetable::Vector{LineInfoNode}, sv::
# Now, if profitable union split the atypes into dispatch tuples and match the appropriate method
nu = countunionsplit(atypes)
if nu != 1 && nu <= sv.params.MAX_UNION_SPLITTING
fully_covered = true
for sig in UnionSplitSignature(atypes)
metharg′ = argtypes_to_type(sig)
if !isdispatchtuple(metharg′)
Expand Down
12 changes: 12 additions & 0 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,11 @@ function insert_node_here!(compact::IncrementalCompact, @nospecialize(val), @nos
ret
end

function getindex(compact::IncrementalCompact, ssa::SSAValue)
@assert ssa.id < compact.result_idx
return compact.result[ssa.id]
end

function getindex(view::TypesView, v::OldSSAValue)
return view.ir.ir.types[v.id]
end
Expand Down Expand Up @@ -523,6 +528,13 @@ function getindex(view::TypesView, idx)
end
end

function setindex!(view::TypesView, @nospecialize(t), idx)
isa(idx, SSAValue) && (idx = idx.id)
ir = view.ir
@assert isa(ir, IRCode)
ir.types[idx] = t
end

start(compact::IncrementalCompact) = (compact.idx, 1)
function done(compact::IncrementalCompact, (idx, _a)::Tuple{Int, Int})
return idx > length(compact.ir.stmts) && (compact.new_nodes_idx > length(compact.perm))
Expand Down
40 changes: 40 additions & 0 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -524,3 +524,43 @@ function type_lift_pass!(ir::IRCode)
end
ir
end

function type_tightening_pass!(ir, sv)
compact = IncrementalCompact(ir)
phi_nodes = Int[]
for (idx, stmt) in compact
if isa(stmt, PhiNode)
isconcretetype(types(compact)[idx]) && continue
push!(phi_nodes, idx)
end
isa(stmt, Expr) || continue
isexpr(stmt, :call) || continue
isconcretetype(types(compact)[idx]) && continue
ft = compact_exprtype(compact, stmt.args[1])
isa(ft, Const) || continue
# Let's be conservative in what we look at here for now
ft.val === select_value || continue
argtypes = Any[compact_exprtype(compact, stmt.args[i]) for i = 2:length(stmt.args)]
rt = builtin_tfunction(ft.val, argtypes, nothing, sv.params)
if !(compact.result_types[idx] rt)
stmt.typ = rt
compact.result_types[idx] = rt
end
end
ir = finish(compact)
# Try to tigthen any phi nodes
for pn_idx in phi_nodes
new_typ = Union{}
pn = ir.stmts[pn_idx]
isa(pn, Nothing) && continue
pn = pn::PhiNode
for i = 1:length(pn.values)
isassigned(pn.values, i) || continue
new_typ = tmerge(new_typ, exprtype(pn.values[i], ir, ir.mod))
end
if !(types(ir)[pn_idx] new_typ)
types(ir)[pn_idx] = new_typ
end
end
ir
end
4 changes: 2 additions & 2 deletions base/compiler/ssair/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ function Base.show(io::IO, code::IRCode)
end
new_nodes = code.new_nodes[filter(i->isassigned(code.new_nodes, i), 1:length(code.new_nodes))]
foreach(nn -> scan_ssa_use!(used, nn.node), new_nodes)
perm = sortperm(new_nodes, by = x->x[1])
perm = sortperm(new_nodes, by = x->x.pos)
new_nodes_perm = Iterators.Stateful(perm)

if isempty(used)
Expand Down Expand Up @@ -122,7 +122,7 @@ function Base.show(io::IO, code::IRCode)
print_sep = true
end
floop = true
while !isempty(new_nodes_perm) && new_nodes[peek(new_nodes_perm)][1] == idx
while !isempty(new_nodes_perm) && new_nodes[peek(new_nodes_perm)].pos == idx
node_idx = popfirst!(new_nodes_perm)
new_node = new_nodes[node_idx]
node_idx += length(code.stmts)
Expand Down
84 changes: 42 additions & 42 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,25 +204,25 @@ add_tfunc(Core.Intrinsics.select_value, 3, 3,
(Bool cnd) || return Bottom
return tmerge(x, y)
end, 1)
add_tfunc(===, 2, 2,
function (@nospecialize(x), @nospecialize(y))
if isa(x, Const) && isa(y, Const)
return Const(x.val === y.val)
elseif typeintersect(widenconst(x), widenconst(y)) === Bottom
return Const(false)
elseif (isa(x, Const) && y === typeof(x.val) && isdefined(y, :instance)) ||
(isa(y, Const) && x === typeof(y.val) && isdefined(x, :instance))
return Const(true)
elseif isa(x, Conditional) && isa(y, Const)
y.val === false && return Conditional(x.var, x.elsetype, x.vtype)
y.val === true && return x
return x
elseif isa(y, Conditional) && isa(x, Const)
x.val === false && return Conditional(y.var, y.elsetype, y.vtype)
x.val === true && return y
end
return Bool
end, 1)
function egal_tfunc(@nospecialize(x), @nospecialize(y))
if isa(x, Const) && isa(y, Const)
return Const(x.val === y.val)
elseif typeintersect(widenconst(x), widenconst(y)) === Bottom
return Const(false)
elseif (isa(x, Const) && y === typeof(x.val) && isdefined(y, :instance)) ||
(isa(y, Const) && x === typeof(y.val) && isdefined(x, :instance))
return Const(true)
elseif isa(x, Conditional) && isa(y, Const)
y.val === false && return Conditional(x.var, x.elsetype, x.vtype)
y.val === true && return x
return x
elseif isa(y, Conditional) && isa(x, Const)
x.val === false && return Conditional(y.var, y.elsetype, y.vtype)
x.val === true && return y
end
return Bool
end
add_tfunc(===, 2, 2, egal_tfunc, 1)
function isdefined_tfunc(args...)
arg1 = args[1]
if isa(arg1, Const)
Expand Down Expand Up @@ -381,29 +381,29 @@ add_tfunc(typeassert, 2, 2,
end
return typeintersect(v, t)
end, 4)
add_tfunc(isa, 2, 2,
function (@nospecialize(v), @nospecialize(t))
t, isexact = instanceof_tfunc(t)
if !has_free_typevars(t)
if t === Bottom
return Const(false)
elseif v t
if isexact
return Const(true)
end
elseif isa(v, Const) || isa(v, Conditional) || isdispatchelem(v)
# this tests for knowledge of a leaftype appearing on the LHS
# (ensuring the isa is precise)
return Const(false)
elseif isexact && typeintersect(v, t) === Bottom
if !iskindtype(v) #= subtyping currently intentionally answers this query incorrectly for kinds =#
return Const(false)
end
end
end
# TODO: handle non-leaftype(t) by testing against lower and upper bounds
return Bool
end, 0)
function isa_tfunc(@nospecialize(v), @nospecialize(t))
t, isexact = instanceof_tfunc(t)
if !has_free_typevars(t)
if t === Bottom
return Const(false)
elseif v t
if isexact
return Const(true)
end
elseif isa(v, Const) || isa(v, Conditional) || isdispatchelem(v)
# this tests for knowledge of a leaftype appearing on the LHS
# (ensuring the isa is precise)
return Const(false)
elseif isexact && typeintersect(v, t) === Bottom
if !iskindtype(v) #= subtyping currently intentionally answers this query incorrectly for kinds =#
return Const(false)
end
end
end
# TODO: handle non-leaftype(t) by testing against lower and upper bounds
return Bool
end
add_tfunc(isa, 2, 2, isa_tfunc, 0)
add_tfunc(<:, 2, 2,
function (@nospecialize(a), @nospecialize(b))
a, isexact_a = instanceof_tfunc(a)
Expand Down

0 comments on commit 51b8f17

Please sign in to comment.