From d233dc523038e07b3e2a0421d7620dba59ca1356 Mon Sep 17 00:00:00 2001 From: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com> Date: Thu, 30 Sep 2021 02:32:45 +0900 Subject: [PATCH] optimizer: improve general code quality (#42357) - add more type signatures - add more `@nospecialize` decls - remove dead/debug code - add some docs on SROA and ADCE passes --- base/compiler/ssair/ir.jl | 2 +- base/compiler/ssair/passes.jl | 90 +++++++++++++++++++++++------------ 2 files changed, 60 insertions(+), 32 deletions(-) diff --git a/base/compiler/ssair/ir.jl b/base/compiler/ssair/ir.jl index d3d62c9b2dfaaf..6008526799ca2f 100644 --- a/base/compiler/ssair/ir.jl +++ b/base/compiler/ssair/ir.jl @@ -1316,7 +1316,7 @@ function iterate(compact::IncrementalCompact, (idx, active_bb)::Tuple{Int, Int}= compact.result[old_result_idx][:inst]), (compact.idx, active_bb) end -function maybe_erase_unused!(extra_worklist, compact, idx, callback = x->nothing) +function maybe_erase_unused!(extra_worklist::Vector{Int}, compact::IncrementalCompact, idx::Int, callback = x::SSAValue->nothing) stmt = compact.result[idx][:inst] stmt === nothing && return false if compact_exprtype(compact, SSAValue(idx)) === Bottom diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 07901f8c2f0a2f..c61852125ca94c 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -110,7 +110,8 @@ function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{I end end -function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#), pi_callback=(pi, idx)->false) +function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#), + callback = (@nospecialize(pi), @nospecialize(idx)) -> false) while true if isa(defssa, OldSSAValue) if already_inserted(compact, defssa) @@ -124,7 +125,7 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA end def = compact[defssa] if isa(def, PiNode) - if pi_callback(def, defssa) + if callback(def, defssa) return defssa end def = def.val @@ -135,7 +136,7 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA end defssa = def elseif isa(def, AnySSAValue) - pi_callback(def, defssa) + callback(def, defssa) if isa(def, SSAValue) is_old(compact, defssa) && (def = OldSSAValue(def.id)) end @@ -148,12 +149,15 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA end end -function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defidx), @nospecialize(typeconstraint) = types(compact)[defidx]) +function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#), + @nospecialize(typeconstraint) = types(compact)[defssa]) callback = function (@nospecialize(pi), @nospecialize(idx)) - isa(pi, PiNode) && (typeconstraint = typeintersect(typeconstraint, widenconst(pi.typ))) + if isa(pi, PiNode) + typeconstraint = typeintersect(typeconstraint, widenconst(pi.typ)) + end return false end - def = simple_walk(compact, defidx, callback) + def = simple_walk(compact, defssa, callback) return Pair{Any, Any}(def, typeconstraint) end @@ -273,8 +277,10 @@ function is_getfield_captures(@nospecialize(def), compact::IncrementalCompact) return oc ⊑ Core.OpaqueClosure end -function lift_leaves(compact::IncrementalCompact, @nospecialize(stmt), - @nospecialize(result_t), field::Int, leaves::Vector{Any}) +# try to compute lifted values that can replace `getfield(x, field)` call +# where `x` is an immutable struct that are defined at any of `leaves` +function lift_leaves(compact::IncrementalCompact, + @nospecialize(result_t), field::Int, leaves::Vector{Any}) # For every leaf, the lifted value lifted_leaves = IdDict{Any, Any}() maybe_undef = false @@ -396,13 +402,13 @@ function lift_leaves(compact::IncrementalCompact, @nospecialize(stmt), elseif isa(leaf, Union{Argument, Expr}) return nothing end - !ismutable(leaf) || return nothing + ismutable(leaf) && return nothing isdefined(leaf, field) || return nothing val = getfield(leaf, field) is_inlineable_constant(val) || return nothing lifted_leaves[leaf_key] = RefValue{Any}(quoted(val)) end - lifted_leaves, maybe_undef + return lifted_leaves, maybe_undef end make_MaybeUndef(@nospecialize(typ)) = isa(typ, MaybeUndef) ? typ : MaybeUndef(typ) @@ -415,13 +421,11 @@ function lift_comparison!(compact::IncrementalCompact, idx::Int, typeconstraint = widenconst(c2) val = stmt.args[3] else - cmp = c2 + cmp = c2::Const typeconstraint = widenconst(c1) val = stmt.args[2] end - is_type_only = isdefined(typeof(cmp), :instance) - if isa(val, Union{OldSSAValue, SSAValue}) val, typeconstraint = simple_walk_constraint(compact, val, typeconstraint) end @@ -497,7 +501,7 @@ function perform_lifting!(compact::IncrementalCompact, if is_old(compact, old_node_ssa) && isa(val, SSAValue) val = OldSSAValue(val.id) end - if isa(val, Union{NewSSAValue, SSAValue, OldSSAValue}) + if isa(val, AnySSAValue) val = simple_walk(compact, val) end if val in keys(lifted_leaves) @@ -508,11 +512,12 @@ function perform_lifting!(compact::IncrementalCompact, continue end lifted_val = lifted_val.x - if isa(lifted_val, Union{NewSSAValue, SSAValue, OldSSAValue}) - lifted_val = simple_walk(compact, lifted_val, (pi, idx)->true) + if isa(lifted_val, AnySSAValue) + callback = (@nospecialize(pi), @nospecialize(idx)) -> true + lifted_val = simple_walk(compact, lifted_val, callback) end push!(new_node.values, lifted_val) - elseif isa(val, Union{NewSSAValue, SSAValue, OldSSAValue}) && val in keys(reverse_mapping) + elseif isa(val, AnySSAValue) && val in keys(reverse_mapping) push!(new_node.edges, edge) push!(new_node.values, lifted_phis[reverse_mapping[val]].ssa) else @@ -532,14 +537,31 @@ function perform_lifting!(compact::IncrementalCompact, if stmt_val in keys(lifted_leaves) stmt_val = lifted_leaves[stmt_val] - elseif isa(stmt_val, Union{NewSSAValue, SSAValue, OldSSAValue}) && stmt_val in keys(reverse_mapping) + elseif isa(stmt_val, AnySSAValue) && stmt_val in keys(reverse_mapping) stmt_val = RefValue{Any}(lifted_phis[reverse_mapping[stmt_val]].ssa) end return stmt_val end -assertion_counter = 0 +""" + getfield_elim_pass!(ir::IRCode) -> newir::IRCode + +`getfield` elimination pass, a.k.a. Scalar Replacements of Aggregates optimization. + +This pass is based on a local alias analysis that collects field information by def-use chain walking. +It looks for struct allocation sites ("definitions"), and `getfield` calls as well as +`:foreigncall`s that preserve the structs ("usages"). If "definitions" have enough information, +then this pass will replace corresponding usages with lifted values. +`mutable struct`s require additional cares and need to be handled separately from immutables. +For `mutable struct`s, `setfield!` calls account for "definitions" also, and the pass should +give up the lifting conservatively when there are any "intermediate usages" that may escape +the mutable struct (e.g. non-inlined generic function call that takes the mutable struct as +its argument). + +In a case when all usages are fully eliminated, `struct` allocation may also be erased as +a result of dead code elimination. +""" function getfield_elim_pass!(ir::IRCode) compact = IncrementalCompact(ir) insertions = Vector{Any}() @@ -554,7 +576,6 @@ function getfield_elim_pass!(ir::IRCode) result_t = compact_exprtype(compact, SSAValue(idx)) is_getfield = is_setfield = false field_ordering = :unspecified - is_ccall = false # Step 1: Check whether the statement we're looking at is a getfield/setfield! if is_known_call(stmt, setfield!, compact) is_setfield = true @@ -610,8 +631,8 @@ function getfield_elim_pass!(ir::IRCode) old_preserves = stmt.args[(6+nccallargs):end] for (pidx, preserved_arg) in enumerate(old_preserves) isa(preserved_arg, SSAValue) || continue - let intermediaries = IdSet() - callback = function(@nospecialize(pi), ssa::AnySSAValue) + let intermediaries = IdSet{Int}() + callback = function (@nospecialize(pi), @nospecialize(ssa)) push!(intermediaries, ssa.id) return false end @@ -670,8 +691,8 @@ function getfield_elim_pass!(ir::IRCode) if ismutabletype(struct_typ) isa(def, SSAValue) || continue - let intermediaries = IdSet() - callback = function(@nospecialize(pi), ssa::AnySSAValue) + let intermediaries = IdSet{Int}() + callback = function (@nospecialize(pi), @nospecialize(ssa)) push!(intermediaries, ssa.id) return false end @@ -691,6 +712,8 @@ function getfield_elim_pass!(ir::IRCode) continue end + # perform SROA on immutable structs here on + if isa(def, Union{OldSSAValue, SSAValue}) def, typeconstraint = simple_walk_constraint(compact, def, typeconstraint) end @@ -703,7 +726,7 @@ function getfield_elim_pass!(ir::IRCode) field = try_compute_fieldidx(struct_typ, field) field === nothing && continue - r = lift_leaves(compact, stmt, result_t, field, leaves) + r = lift_leaves(compact, result_t, field, leaves) r === nothing && continue lifted_leaves, any_undef = r @@ -736,14 +759,13 @@ function getfield_elim_pass!(ir::IRCode) @assert val !== nothing end - global assertion_counter - assertion_counter::Int += 1 + # global assertion_counter + # assertion_counter::Int += 1 #insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), val), nothing, 0, true) #continue compact[idx] = val === nothing ? nothing : val.x end - non_dce_finish!(compact) # Copy the use count, `simple_dce!` may modify it and for our predicate # below we need it consistent with the state of the IR here (after tracking @@ -874,11 +896,12 @@ function getfield_elim_pass!(ir::IRCode) end ir end +# assertion_counter = 0 function adce_erase!(phi_uses::Vector{Int}, extra_worklist::Vector{Int}, compact::IncrementalCompact, idx::Int) # return whether this made a change if isa(compact.result[idx][:inst], PhiNode) - return maybe_erase_unused!(extra_worklist, compact, idx, val -> phi_uses[val.id] -= 1) + return maybe_erase_unused!(extra_worklist, compact, idx, val::SSAValue -> phi_uses[val.id] -= 1) else return maybe_erase_unused!(extra_worklist, compact, idx) end @@ -893,7 +916,7 @@ function count_uses(@nospecialize(stmt), uses::Vector{Int}) end end -function mark_phi_cycles(compact::IncrementalCompact, safe_phis::BitSet, phi::Int) +function mark_phi_cycles!(compact::IncrementalCompact, safe_phis::BitSet, phi::Int) worklist = Int[] push!(worklist, phi) while !isempty(worklist) @@ -909,6 +932,11 @@ function mark_phi_cycles(compact::IncrementalCompact, safe_phis::BitSet, phi::In end end +""" + adce_pass!(ir::IRCode) -> newir::IRCode + +Aggressive Dead Code Elimination pass to eliminate code. +""" function adce_pass!(ir::IRCode) phi_uses = fill(0, length(ir.stmts) + length(ir.new_nodes)) all_phis = Int[] @@ -940,7 +968,7 @@ function adce_pass!(ir::IRCode) for phi in all_phis # Save any phi cycles that have non-phi uses if compact.used_ssas[phi] - phi_uses[phi] != 0 - mark_phi_cycles(compact, safe_phis, phi) + mark_phi_cycles!(compact, safe_phis, phi) end end for phi in all_phis