diff --git a/base/compiler/ssair/inlining2.jl b/base/compiler/ssair/inlining2.jl index 3082b5cfb0c2b..bcb6afe4a033c 100644 --- a/base/compiler/ssair/inlining2.jl +++ b/base/compiler/ssair/inlining2.jl @@ -350,6 +350,61 @@ function ir_inline_item!(compact::IncrementalCompact, idx::Int, argexprs::Vector return_value end +# Constraints are generally small, so a linear search is the bets option +function find_constraint(val, constraints) + for i = 1:length(constraints) + if val === constraints[i][1] + return constraints[i][2] + end + end + return nothing +end + +# Performs minimal backwards inference to catch a couple of interesting, common cases +function minimal_backinf(compact, constraints, unconstrained_types, argexprs) + for i = 2:length(argexprs) + isa(argexprs[i], SSAValue) || continue + # Check if the argexpr is in the constraint list directly + c = find_constraint(argexprs[i], constraints) + if c !== nothing + unconstrained_types[i] = c + end + # For boolean values check for type predicates on any of the constraints + ut = unconstrained_types[i] + if ut === Bool + def = compact[argexprs[i]] + isa(def, Expr) || continue + if is_known_call(def, ===, compact) + v1, v2, = def.args[2:3] + c = find_constraint(v1, constraints) + if c !== nothing + refined = egal_tfunc(c, compact_exprtype(compact, v2)) + if !(ut ⊑ refined) + unconstrained_types[i] = refined + end + end + c = find_constraint(v2, constraints) + if c !== nothing + refined = egal_tfunc(compact_exprtype(compact, v1), c) + if !(ut ⊑ refined) + unconstrained_types[i] = refined + end + end + elseif is_known_call(def, isa, compact) + v = def.args[2] + c = find_constraint(v, constraints) + if c !== nothing + refined = isa_tfunc(c, compact_exprtype(compact, def.args[3])) + if !(ut ⊑ refined) + unconstrained_types[i] = refined + end + end + end + end + end + unconstrained_types +end + function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, argexprs::Vector{Any}, linetable::Vector{LineInfoNode}, item::UnionSplit, boundscheck::Symbol, todo_bbs::Vector{Tuple{Int, Int}}) @@ -383,11 +438,40 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int, insert_node_here!(compact, GotoIfNot(cond, next_cond_bb), Union{}, line) bb = next_cond_bb - 1 finish_current_bb!(compact) - # Insert Pi nodes here + if !isa(case, ConstantCase) + argexprs′ = copy(argexprs) + constraints = Pair{SSAValue, Any}[] + unconstrained_types = Any[atype.parameters...] + for i = 2:length(metharg.parameters) + a, m = unconstrained_types[i], metharg.parameters[i] + isa(argexprs[i], SSAValue) || continue + if !(a <: m) + push!(constraints, Pair{SSAValue, Any}(argexprs[i], m)) + end + end + constrained_types = minimal_backinf(compact, constraints, unconstrained_types, argexprs) + for i = 2:length(metharg.parameters) + if !(atype.parameters[i] ⊑ constrained_types[i]) + if isa(constrained_types[i], Const) + argexprs′[i] = constrained_types[i].val + else + ct = widenconst(constrained_types[i]) + if isa(ct, DataType) && isdefined(ct, :instance) + argexprs′[i] = ct.instance + else + argexprs′[i] = insert_node_here!(compact, PiNode(argexprs′[i], constrained_types[i]), + constrained_types[i], line) + end + end + end + end + else + argexprs′ = argexprs + end if isa(case, InliningTodo) - val = ir_inline_item!(compact, idx, argexprs, linetable, case, boundscheck, todo_bbs) + val = ir_inline_item!(compact, idx, argexprs′, linetable, case, boundscheck, todo_bbs) elseif isa(case, MethodInstance) - val = insert_node_here!(compact, Expr(:invoke, case, argexprs...), typ, line) + val = insert_node_here!(compact, Expr(:invoke, case, argexprs′...), typ, line) else case = case::ConstantCase val = case.val @@ -865,6 +949,7 @@ function assemble_inline_todo!(ir::IRCode, linetable::Vector{LineInfoNode}, sv:: # Now, if profitable union split the atypes into dispatch tuples and match the appropriate method nu = countunionsplit(atypes) if nu != 1 && nu <= sv.params.MAX_UNION_SPLITTING + fully_covered = true for sig in UnionSplitSignature(atypes) metharg′ = argtypes_to_type(sig) if !isdispatchtuple(metharg′) diff --git a/base/compiler/ssair/ir.jl b/base/compiler/ssair/ir.jl index 97a202855a1f6..82956719b4e12 100644 --- a/base/compiler/ssair/ir.jl +++ b/base/compiler/ssair/ir.jl @@ -483,6 +483,11 @@ function insert_node_here!(compact::IncrementalCompact, @nospecialize(val), @nos ret end +function getindex(compact::IncrementalCompact, ssa::SSAValue) + @assert ssa.id < compact.result_idx + return compact.result[ssa.id] +end + function getindex(view::TypesView, v::OldSSAValue) return view.ir.ir.types[v.id] end @@ -523,6 +528,13 @@ function getindex(view::TypesView, idx) end end +function setindex!(view::TypesView, @nospecialize(t), idx) + isa(idx, SSAValue) && (idx = idx.id) + ir = view.ir + @assert isa(ir, IRCode) + ir.types[idx] = t +end + start(compact::IncrementalCompact) = (compact.idx, 1) function done(compact::IncrementalCompact, (idx, _a)::Tuple{Int, Int}) return idx > length(compact.ir.stmts) && (compact.new_nodes_idx > length(compact.perm)) diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 228eb87d0e205..df645f3fb76ee 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -524,3 +524,43 @@ function type_lift_pass!(ir::IRCode) end ir end + +function type_tightening_pass!(ir, sv) + compact = IncrementalCompact(ir) + phi_nodes = Int[] + for (idx, stmt) in compact + if isa(stmt, PhiNode) + isconcretetype(types(compact)[idx]) && continue + push!(phi_nodes, idx) + end + isa(stmt, Expr) || continue + isexpr(stmt, :call) || continue + isconcretetype(types(compact)[idx]) && continue + ft = compact_exprtype(compact, stmt.args[1]) + isa(ft, Const) || continue + # Let's be conservative in what we look at here for now + ft.val === select_value || continue + argtypes = Any[compact_exprtype(compact, stmt.args[i]) for i = 2:length(stmt.args)] + rt = builtin_tfunction(ft.val, argtypes, nothing, sv.params) + if !(compact.result_types[idx] ⊑ rt) + stmt.typ = rt + compact.result_types[idx] = rt + end + end + ir = finish(compact) + # Try to tigthen any phi nodes + for pn_idx in phi_nodes + new_typ = Union{} + pn = ir.stmts[pn_idx] + isa(pn, Nothing) && continue + pn = pn::PhiNode + for i = 1:length(pn.values) + isassigned(pn.values, i) || continue + new_typ = tmerge(new_typ, exprtype(pn.values[i], ir, ir.mod)) + end + if !(types(ir)[pn_idx] ⊑ new_typ) + types(ir)[pn_idx] = new_typ + end + end + ir +end diff --git a/base/compiler/ssair/show.jl b/base/compiler/ssair/show.jl index 26b734586e596..d5322dad9b259 100644 --- a/base/compiler/ssair/show.jl +++ b/base/compiler/ssair/show.jl @@ -88,7 +88,7 @@ function Base.show(io::IO, code::IRCode) end new_nodes = code.new_nodes[filter(i->isassigned(code.new_nodes, i), 1:length(code.new_nodes))] foreach(nn -> scan_ssa_use!(used, nn.node), new_nodes) - perm = sortperm(new_nodes, by = x->x[1]) + perm = sortperm(new_nodes, by = x->x.pos) new_nodes_perm = Iterators.Stateful(perm) if isempty(used) @@ -122,7 +122,7 @@ function Base.show(io::IO, code::IRCode) print_sep = true end floop = true - while !isempty(new_nodes_perm) && new_nodes[peek(new_nodes_perm)][1] == idx + while !isempty(new_nodes_perm) && new_nodes[peek(new_nodes_perm)].pos == idx node_idx = popfirst!(new_nodes_perm) new_node = new_nodes[node_idx] node_idx += length(code.stmts) diff --git a/base/compiler/tfuncs.jl b/base/compiler/tfuncs.jl index a291efdf61962..2a689ab8659a7 100644 --- a/base/compiler/tfuncs.jl +++ b/base/compiler/tfuncs.jl @@ -204,25 +204,25 @@ add_tfunc(Core.Intrinsics.select_value, 3, 3, (Bool ⊑ cnd) || return Bottom return tmerge(x, y) end, 1) -add_tfunc(===, 2, 2, - function (@nospecialize(x), @nospecialize(y)) - if isa(x, Const) && isa(y, Const) - return Const(x.val === y.val) - elseif typeintersect(widenconst(x), widenconst(y)) === Bottom - return Const(false) - elseif (isa(x, Const) && y === typeof(x.val) && isdefined(y, :instance)) || - (isa(y, Const) && x === typeof(y.val) && isdefined(x, :instance)) - return Const(true) - elseif isa(x, Conditional) && isa(y, Const) - y.val === false && return Conditional(x.var, x.elsetype, x.vtype) - y.val === true && return x - return x - elseif isa(y, Conditional) && isa(x, Const) - x.val === false && return Conditional(y.var, y.elsetype, y.vtype) - x.val === true && return y - end - return Bool - end, 1) +function egal_tfunc(@nospecialize(x), @nospecialize(y)) + if isa(x, Const) && isa(y, Const) + return Const(x.val === y.val) + elseif typeintersect(widenconst(x), widenconst(y)) === Bottom + return Const(false) + elseif (isa(x, Const) && y === typeof(x.val) && isdefined(y, :instance)) || + (isa(y, Const) && x === typeof(y.val) && isdefined(x, :instance)) + return Const(true) + elseif isa(x, Conditional) && isa(y, Const) + y.val === false && return Conditional(x.var, x.elsetype, x.vtype) + y.val === true && return x + return x + elseif isa(y, Conditional) && isa(x, Const) + x.val === false && return Conditional(y.var, y.elsetype, y.vtype) + x.val === true && return y + end + return Bool +end +add_tfunc(===, 2, 2, egal_tfunc, 1) function isdefined_tfunc(args...) arg1 = args[1] if isa(arg1, Const) @@ -381,29 +381,29 @@ add_tfunc(typeassert, 2, 2, end return typeintersect(v, t) end, 4) -add_tfunc(isa, 2, 2, - function (@nospecialize(v), @nospecialize(t)) - t, isexact = instanceof_tfunc(t) - if !has_free_typevars(t) - if t === Bottom - return Const(false) - elseif v ⊑ t - if isexact - return Const(true) - end - elseif isa(v, Const) || isa(v, Conditional) || isdispatchelem(v) - # this tests for knowledge of a leaftype appearing on the LHS - # (ensuring the isa is precise) - return Const(false) - elseif isexact && typeintersect(v, t) === Bottom - if !iskindtype(v) #= subtyping currently intentionally answers this query incorrectly for kinds =# - return Const(false) - end - end - end - # TODO: handle non-leaftype(t) by testing against lower and upper bounds - return Bool - end, 0) +function isa_tfunc(@nospecialize(v), @nospecialize(t)) + t, isexact = instanceof_tfunc(t) + if !has_free_typevars(t) + if t === Bottom + return Const(false) + elseif v ⊑ t + if isexact + return Const(true) + end + elseif isa(v, Const) || isa(v, Conditional) || isdispatchelem(v) + # this tests for knowledge of a leaftype appearing on the LHS + # (ensuring the isa is precise) + return Const(false) + elseif isexact && typeintersect(v, t) === Bottom + if !iskindtype(v) #= subtyping currently intentionally answers this query incorrectly for kinds =# + return Const(false) + end + end + end + # TODO: handle non-leaftype(t) by testing against lower and upper bounds + return Bool +end +add_tfunc(isa, 2, 2, isa_tfunc, 0) add_tfunc(<:, 2, 2, function (@nospecialize(a), @nospecialize(b)) a, isexact_a = instanceof_tfunc(a)