Skip to content

Commit

Permalink
optimizer: improve general code quality (JuliaLang#42357)
Browse files Browse the repository at this point in the history
- add more type signatures
- add more `@nospecialize` decls
- remove dead/debug code
- add some docs on SROA and ADCE passes
  • Loading branch information
aviatesk authored and LilithHafner committed Mar 8, 2022
1 parent a1bf852 commit d233dc5
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 32 deletions.
2 changes: 1 addition & 1 deletion base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1316,7 +1316,7 @@ function iterate(compact::IncrementalCompact, (idx, active_bb)::Tuple{Int, Int}=
compact.result[old_result_idx][:inst]), (compact.idx, active_bb)
end

function maybe_erase_unused!(extra_worklist, compact, idx, callback = x->nothing)
function maybe_erase_unused!(extra_worklist::Vector{Int}, compact::IncrementalCompact, idx::Int, callback = x::SSAValue->nothing)
stmt = compact.result[idx][:inst]
stmt === nothing && return false
if compact_exprtype(compact, SSAValue(idx)) === Bottom
Expand Down
90 changes: 59 additions & 31 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{I
end
end

function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#), pi_callback=(pi, idx)->false)
function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
callback = (@nospecialize(pi), @nospecialize(idx)) -> false)
while true
if isa(defssa, OldSSAValue)
if already_inserted(compact, defssa)
Expand All @@ -124,7 +125,7 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
end
def = compact[defssa]
if isa(def, PiNode)
if pi_callback(def, defssa)
if callback(def, defssa)
return defssa
end
def = def.val
Expand All @@ -135,7 +136,7 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
end
defssa = def
elseif isa(def, AnySSAValue)
pi_callback(def, defssa)
callback(def, defssa)
if isa(def, SSAValue)
is_old(compact, defssa) && (def = OldSSAValue(def.id))
end
Expand All @@ -148,12 +149,15 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
end
end

function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defidx), @nospecialize(typeconstraint) = types(compact)[defidx])
function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
@nospecialize(typeconstraint) = types(compact)[defssa])
callback = function (@nospecialize(pi), @nospecialize(idx))
isa(pi, PiNode) && (typeconstraint = typeintersect(typeconstraint, widenconst(pi.typ)))
if isa(pi, PiNode)
typeconstraint = typeintersect(typeconstraint, widenconst(pi.typ))
end
return false
end
def = simple_walk(compact, defidx, callback)
def = simple_walk(compact, defssa, callback)
return Pair{Any, Any}(def, typeconstraint)
end

Expand Down Expand Up @@ -273,8 +277,10 @@ function is_getfield_captures(@nospecialize(def), compact::IncrementalCompact)
return oc Core.OpaqueClosure
end

function lift_leaves(compact::IncrementalCompact, @nospecialize(stmt),
@nospecialize(result_t), field::Int, leaves::Vector{Any})
# try to compute lifted values that can replace `getfield(x, field)` call
# where `x` is an immutable struct that are defined at any of `leaves`
function lift_leaves(compact::IncrementalCompact,
@nospecialize(result_t), field::Int, leaves::Vector{Any})
# For every leaf, the lifted value
lifted_leaves = IdDict{Any, Any}()
maybe_undef = false
Expand Down Expand Up @@ -396,13 +402,13 @@ function lift_leaves(compact::IncrementalCompact, @nospecialize(stmt),
elseif isa(leaf, Union{Argument, Expr})
return nothing
end
!ismutable(leaf) || return nothing
ismutable(leaf) && return nothing
isdefined(leaf, field) || return nothing
val = getfield(leaf, field)
is_inlineable_constant(val) || return nothing
lifted_leaves[leaf_key] = RefValue{Any}(quoted(val))
end
lifted_leaves, maybe_undef
return lifted_leaves, maybe_undef
end

make_MaybeUndef(@nospecialize(typ)) = isa(typ, MaybeUndef) ? typ : MaybeUndef(typ)
Expand All @@ -415,13 +421,11 @@ function lift_comparison!(compact::IncrementalCompact, idx::Int,
typeconstraint = widenconst(c2)
val = stmt.args[3]
else
cmp = c2
cmp = c2::Const
typeconstraint = widenconst(c1)
val = stmt.args[2]
end

is_type_only = isdefined(typeof(cmp), :instance)

if isa(val, Union{OldSSAValue, SSAValue})
val, typeconstraint = simple_walk_constraint(compact, val, typeconstraint)
end
Expand Down Expand Up @@ -497,7 +501,7 @@ function perform_lifting!(compact::IncrementalCompact,
if is_old(compact, old_node_ssa) && isa(val, SSAValue)
val = OldSSAValue(val.id)
end
if isa(val, Union{NewSSAValue, SSAValue, OldSSAValue})
if isa(val, AnySSAValue)
val = simple_walk(compact, val)
end
if val in keys(lifted_leaves)
Expand All @@ -508,11 +512,12 @@ function perform_lifting!(compact::IncrementalCompact,
continue
end
lifted_val = lifted_val.x
if isa(lifted_val, Union{NewSSAValue, SSAValue, OldSSAValue})
lifted_val = simple_walk(compact, lifted_val, (pi, idx)->true)
if isa(lifted_val, AnySSAValue)
callback = (@nospecialize(pi), @nospecialize(idx)) -> true
lifted_val = simple_walk(compact, lifted_val, callback)
end
push!(new_node.values, lifted_val)
elseif isa(val, Union{NewSSAValue, SSAValue, OldSSAValue}) && val in keys(reverse_mapping)
elseif isa(val, AnySSAValue) && val in keys(reverse_mapping)
push!(new_node.edges, edge)
push!(new_node.values, lifted_phis[reverse_mapping[val]].ssa)
else
Expand All @@ -532,14 +537,31 @@ function perform_lifting!(compact::IncrementalCompact,

if stmt_val in keys(lifted_leaves)
stmt_val = lifted_leaves[stmt_val]
elseif isa(stmt_val, Union{NewSSAValue, SSAValue, OldSSAValue}) && stmt_val in keys(reverse_mapping)
elseif isa(stmt_val, AnySSAValue) && stmt_val in keys(reverse_mapping)
stmt_val = RefValue{Any}(lifted_phis[reverse_mapping[stmt_val]].ssa)
end

return stmt_val
end

assertion_counter = 0
"""
getfield_elim_pass!(ir::IRCode) -> newir::IRCode
`getfield` elimination pass, a.k.a. Scalar Replacements of Aggregates optimization.
This pass is based on a local alias analysis that collects field information by def-use chain walking.
It looks for struct allocation sites ("definitions"), and `getfield` calls as well as
`:foreigncall`s that preserve the structs ("usages"). If "definitions" have enough information,
then this pass will replace corresponding usages with lifted values.
`mutable struct`s require additional cares and need to be handled separately from immutables.
For `mutable struct`s, `setfield!` calls account for "definitions" also, and the pass should
give up the lifting conservatively when there are any "intermediate usages" that may escape
the mutable struct (e.g. non-inlined generic function call that takes the mutable struct as
its argument).
In a case when all usages are fully eliminated, `struct` allocation may also be erased as
a result of dead code elimination.
"""
function getfield_elim_pass!(ir::IRCode)
compact = IncrementalCompact(ir)
insertions = Vector{Any}()
Expand All @@ -554,7 +576,6 @@ function getfield_elim_pass!(ir::IRCode)
result_t = compact_exprtype(compact, SSAValue(idx))
is_getfield = is_setfield = false
field_ordering = :unspecified
is_ccall = false
# Step 1: Check whether the statement we're looking at is a getfield/setfield!
if is_known_call(stmt, setfield!, compact)
is_setfield = true
Expand Down Expand Up @@ -610,8 +631,8 @@ function getfield_elim_pass!(ir::IRCode)
old_preserves = stmt.args[(6+nccallargs):end]
for (pidx, preserved_arg) in enumerate(old_preserves)
isa(preserved_arg, SSAValue) || continue
let intermediaries = IdSet()
callback = function(@nospecialize(pi), ssa::AnySSAValue)
let intermediaries = IdSet{Int}()
callback = function (@nospecialize(pi), @nospecialize(ssa))
push!(intermediaries, ssa.id)
return false
end
Expand Down Expand Up @@ -670,8 +691,8 @@ function getfield_elim_pass!(ir::IRCode)

if ismutabletype(struct_typ)
isa(def, SSAValue) || continue
let intermediaries = IdSet()
callback = function(@nospecialize(pi), ssa::AnySSAValue)
let intermediaries = IdSet{Int}()
callback = function (@nospecialize(pi), @nospecialize(ssa))
push!(intermediaries, ssa.id)
return false
end
Expand All @@ -691,6 +712,8 @@ function getfield_elim_pass!(ir::IRCode)
continue
end

# perform SROA on immutable structs here on

if isa(def, Union{OldSSAValue, SSAValue})
def, typeconstraint = simple_walk_constraint(compact, def, typeconstraint)
end
Expand All @@ -703,7 +726,7 @@ function getfield_elim_pass!(ir::IRCode)
field = try_compute_fieldidx(struct_typ, field)
field === nothing && continue

r = lift_leaves(compact, stmt, result_t, field, leaves)
r = lift_leaves(compact, result_t, field, leaves)
r === nothing && continue
lifted_leaves, any_undef = r

Expand Down Expand Up @@ -736,14 +759,13 @@ function getfield_elim_pass!(ir::IRCode)
@assert val !== nothing
end

global assertion_counter
assertion_counter::Int += 1
# global assertion_counter
# assertion_counter::Int += 1
#insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), val), nothing, 0, true)
#continue
compact[idx] = val === nothing ? nothing : val.x
end


non_dce_finish!(compact)
# Copy the use count, `simple_dce!` may modify it and for our predicate
# below we need it consistent with the state of the IR here (after tracking
Expand Down Expand Up @@ -874,11 +896,12 @@ function getfield_elim_pass!(ir::IRCode)
end
ir
end
# assertion_counter = 0

function adce_erase!(phi_uses::Vector{Int}, extra_worklist::Vector{Int}, compact::IncrementalCompact, idx::Int)
# return whether this made a change
if isa(compact.result[idx][:inst], PhiNode)
return maybe_erase_unused!(extra_worklist, compact, idx, val -> phi_uses[val.id] -= 1)
return maybe_erase_unused!(extra_worklist, compact, idx, val::SSAValue -> phi_uses[val.id] -= 1)
else
return maybe_erase_unused!(extra_worklist, compact, idx)
end
Expand All @@ -893,7 +916,7 @@ function count_uses(@nospecialize(stmt), uses::Vector{Int})
end
end

function mark_phi_cycles(compact::IncrementalCompact, safe_phis::BitSet, phi::Int)
function mark_phi_cycles!(compact::IncrementalCompact, safe_phis::BitSet, phi::Int)
worklist = Int[]
push!(worklist, phi)
while !isempty(worklist)
Expand All @@ -909,6 +932,11 @@ function mark_phi_cycles(compact::IncrementalCompact, safe_phis::BitSet, phi::In
end
end

"""
adce_pass!(ir::IRCode) -> newir::IRCode
Aggressive Dead Code Elimination pass to eliminate code.
"""
function adce_pass!(ir::IRCode)
phi_uses = fill(0, length(ir.stmts) + length(ir.new_nodes))
all_phis = Int[]
Expand Down Expand Up @@ -940,7 +968,7 @@ function adce_pass!(ir::IRCode)
for phi in all_phis
# Save any phi cycles that have non-phi uses
if compact.used_ssas[phi] - phi_uses[phi] != 0
mark_phi_cycles(compact, safe_phis, phi)
mark_phi_cycles!(compact, safe_phis, phi)
end
end
for phi in all_phis
Expand Down

0 comments on commit d233dc5

Please sign in to comment.