Skip to content

Commit

Permalink
Merge pull request #47051 from JuliaLang/avi/ircleanup
Browse files Browse the repository at this point in the history
optimizer: refactors on SSAIR
  • Loading branch information
aviatesk authored Oct 7, 2022
2 parents 5334fa8 + b794a5a commit 8d783ef
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 92 deletions.
2 changes: 1 addition & 1 deletion base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -619,7 +619,7 @@ function ir_inline_unionsplit!(compact::IncrementalCompact, idx::Int,
elseif isa(case, InvokeCase)
inst = Expr(:invoke, case.invoke, argexprs′...)
flag = flags_for_effects(case.effects)
val = insert_node_here!(compact, NewInstruction(inst, typ, NoCallInfo(), line, flag, true))
val = insert_node_here!(compact, NewInstruction(inst, typ, NoCallInfo(), line, flag))
else
case = case::ConstantCase
val = case.val
Expand Down
194 changes: 105 additions & 89 deletions base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ end
InstructionStream() = InstructionStream(0)
length(is::InstructionStream) = length(is.inst)
isempty(is::InstructionStream) = isempty(is.inst)
function add!(is::InstructionStream)
function add_new_idx!(is::InstructionStream)
ninst = length(is) + 1
resize!(is, ninst)
return ninst
Expand Down Expand Up @@ -236,7 +236,7 @@ struct Instruction
data::InstructionStream
idx::Int
end
Instruction(is::InstructionStream) = Instruction(is, add!(is))
Instruction(is::InstructionStream) = Instruction(is, add_new_idx!(is))

@inline function getindex(node::Instruction, fld::Symbol)
isdefined(node, fld) && return getfield(node, fld)
Expand Down Expand Up @@ -278,7 +278,7 @@ end
NewNodeStream(len::Int=0) = NewNodeStream(InstructionStream(len), fill(NewNodeInfo(0, false), len))
length(new::NewNodeStream) = length(new.stmts)
isempty(new::NewNodeStream) = isempty(new.stmts)
function add!(new::NewNodeStream, pos::Int, attach_after::Bool)
function add_inst!(new::NewNodeStream, pos::Int, attach_after::Bool)
push!(new.info, NewNodeInfo(pos, attach_after))
return Instruction(new.stmts)
end
Expand All @@ -288,34 +288,48 @@ struct NewInstruction
stmt::Any
type::Any
info::CallInfo
# If nothing, copy the line from previous statement
# in the insertion location
line::Union{Int32, Nothing}
flag::UInt8

## Insertion options

# The IR_FLAG_EFFECT_FREE flag has already been computed (or forced).
# Don't bother redoing so on insertion.
effect_free_computed::Bool
NewInstruction(@nospecialize(stmt), @nospecialize(type), @nospecialize(info::CallInfo),
line::Union{Int32, Nothing}, flag::UInt8, effect_free_computed::Bool) =
new(stmt, type, info, line, flag, effect_free_computed)
end
NewInstruction(@nospecialize(stmt), @nospecialize(type)) =
NewInstruction(stmt, type, nothing)
NewInstruction(@nospecialize(stmt), @nospecialize(type), line::Union{Nothing, Int32}) =
NewInstruction(stmt, type, NoCallInfo(), line, IR_FLAG_NULL, false)
NewInstruction(@nospecialize(stmt), meta::Instruction; line::Union{Int32, Nothing}=nothing) =
NewInstruction(stmt, meta[:type], meta[:info], line === nothing ? meta[:line] : line, meta[:flag], true)

effect_free(inst::NewInstruction) =
NewInstruction(inst.stmt, inst.type, inst.info, inst.line, inst.flag | IR_FLAG_EFFECT_FREE, true)
non_effect_free(inst::NewInstruction) =
NewInstruction(inst.stmt, inst.type, inst.info, inst.line, inst.flag & ~IR_FLAG_EFFECT_FREE, true)
with_flags(inst::NewInstruction, flags::UInt8) =
NewInstruction(inst.stmt, inst.type, inst.info, inst.line, inst.flag | flags, true)

line::Union{Int32,Nothing} # if nothing, copy the line from previous statement in the insertion location
flag::Union{UInt8,Nothing} # if nothing, IR flags will be recomputed on insertion
function NewInstruction(@nospecialize(stmt), @nospecialize(type), @nospecialize(info::CallInfo),
line::Union{Int32,Nothing}, flag::Union{UInt8,Nothing})
return new(stmt, type, info, line, flag)
end
end
function NewInstruction(@nospecialize(stmt), @nospecialize(type), line::Union{Int32,Nothing}=nothing)
return NewInstruction(stmt, type, NoCallInfo(), line, nothing)
end
@nospecialize
function NewInstruction(newinst::NewInstruction;
stmt::Any=newinst.stmt,
type::Any=newinst.type,
info::CallInfo=newinst.info,
line::Union{Int32,Nothing}=newinst.line,
flag::Union{UInt8,Nothing}=newinst.flag)
return NewInstruction(stmt, type, info, line, flag)
end
function NewInstruction(inst::Instruction;
stmt::Any=inst[:inst],
type::Any=inst[:type],
info::CallInfo=inst[:info],
line::Union{Int32,Nothing}=inst[:line],
flag::Union{UInt8,Nothing}=inst[:flag])
return NewInstruction(stmt, type, info, line, flag)
end
@specialize
effect_free(newinst::NewInstruction) = NewInstruction(newinst; flag=add_flag(newinst, IR_FLAG_EFFECT_FREE))
non_effect_free(newinst::NewInstruction) = NewInstruction(newinst; flag=sub_flag(newinst, IR_FLAG_EFFECT_FREE))
with_flags(newinst::NewInstruction, flags::UInt8) = NewInstruction(newinst; flag=add_flag(newinst, flags))
without_flags(newinst::NewInstruction, flags::UInt8) = NewInstruction(newinst; flag=sub_flag(newinst, flags))
function add_flag(newinst::NewInstruction, newflag::UInt8)
flag = newinst.flag
flag === nothing && return newflag
return flag | newflag
end
function sub_flag(newinst::NewInstruction, newflag::UInt8)
flag = newinst.flag
flag === nothing && return IR_FLAG_NULL
return flag & ~newflag
end

struct IRCode
stmts::InstructionStream
Expand All @@ -332,8 +346,7 @@ struct IRCode
function IRCode(ir::IRCode, stmts::InstructionStream, cfg::CFG, new_nodes::NewNodeStream)
return new(stmts, ir.argtypes, ir.sptypes, ir.linetable, cfg, new_nodes, ir.meta)
end
global copy
copy(ir::IRCode) = new(copy(ir.stmts), copy(ir.argtypes), copy(ir.sptypes),
global copy(ir::IRCode) = new(copy(ir.stmts), copy(ir.argtypes), copy(ir.sptypes),
copy(ir.linetable), copy(ir.cfg), copy(ir.new_nodes), copy(ir.meta))
end

Expand Down Expand Up @@ -513,26 +526,15 @@ scan_ssa_use!(@specialize(push!), used, @nospecialize(stmt)) = foreachssa(ssa::S
# Manually specialized copy of the above with push! === Compiler.push!
scan_ssa_use!(used::IdSet, @nospecialize(stmt)) = foreachssa(ssa::SSAValue -> push!(used, ssa.id), stmt)

function insert_node!(ir::IRCode, pos::SSAValue, inst::NewInstruction, attach_after::Bool=false)
node = add!(ir.new_nodes, pos.id, attach_after)
node[:line] = something(inst.line, ir[pos][:line])
flag = inst.flag
if !inst.effect_free_computed
(consistent, effect_free_and_nothrow, nothrow) = stmt_effect_flags(fallback_lattice, inst.stmt, inst.type, ir)
if consistent
flag |= IR_FLAG_CONSISTENT
end
if effect_free_and_nothrow
flag |= IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW
elseif nothrow
flag |= IR_FLAG_NOTHROW
end
end
node[:inst], node[:type], node[:flag], node[:info] = inst.stmt, inst.type, flag, inst.info
function insert_node!(ir::IRCode, pos::SSAValue, newinst::NewInstruction, attach_after::Bool=false)
node = add_inst!(ir.new_nodes, pos.id, attach_after)
newline = something(newinst.line, ir[pos][:line])
newflag = recompute_inst_flag(newinst, ir)
node = inst_from_newinst!(node, newinst, newline, newflag)
return SSAValue(length(ir.stmts) + node.idx)
end
insert_node!(ir::IRCode, pos::Int, inst::NewInstruction, attach_after::Bool=false) =
insert_node!(ir, SSAValue(pos), inst, attach_after)
insert_node!(ir::IRCode, pos::Int, newinst::NewInstruction, attach_after::Bool=false) =
insert_node!(ir, SSAValue(pos), newinst, attach_after)

# For bootstrapping
function my_sortperm(v)
Expand Down Expand Up @@ -769,27 +771,54 @@ function count_added_node!(compact::IncrementalCompact, @nospecialize(v))
end

function add_pending!(compact::IncrementalCompact, pos::Int, attach_after::Bool)
node = add!(compact.pending_nodes, pos, attach_after)
node = add_inst!(compact.pending_nodes, pos, attach_after)
# TODO: switch this to `l = length(pending_nodes); splice!(pending_perm, searchsorted(pending_perm, l), l)`
push!(compact.pending_perm, length(compact.pending_nodes))
sort!(compact.pending_perm, DEFAULT_STABLE, Order.By(x->compact.pending_nodes.info[x].pos, Order.Forward))
return node
end

function insert_node!(compact::IncrementalCompact, before, inst::NewInstruction, attach_after::Bool=false)
@assert inst.effect_free_computed
function inst_from_newinst!(node::Instruction, newinst::NewInstruction,
newline::Int32=newinst.line::Int32, newflag::UInt8=newinst.flag::UInt8)
node[:inst] = newinst.stmt
node[:type] = newinst.type
node[:info] = newinst.info
node[:line] = newline
node[:flag] = newflag
return node
end

function recompute_inst_flag(newinst::NewInstruction, src::Union{IRCode,IncrementalCompact})
flag = newinst.flag
flag !== nothing && return flag
flag = IR_FLAG_NULL
(consistent, effect_free_and_nothrow, nothrow) = stmt_effect_flags(
fallback_lattice, newinst.stmt, newinst.type, src)
if consistent
flag |= IR_FLAG_CONSISTENT
end
if effect_free_and_nothrow
flag |= IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW
elseif nothrow
flag |= IR_FLAG_NOTHROW
end
return flag
end

function insert_node!(compact::IncrementalCompact, @nospecialize(before), newinst::NewInstruction, attach_after::Bool=false)
newflag = newinst.flag::UInt8
if isa(before, SSAValue)
if before.id < compact.result_idx
count_added_node!(compact, inst.stmt)
line = something(inst.line, compact.result[before.id][:line])
node = add!(compact.new_new_nodes, before.id, attach_after)
count_added_node!(compact, newinst.stmt)
newline = something(newinst.line, compact.result[before.id][:line])
node = add_inst!(compact.new_new_nodes, before.id, attach_after)
node = inst_from_newinst!(node, newinst, newline, newflag)
push!(compact.new_new_used_ssas, 0)
node[:inst], node[:type], node[:line], node[:flag] = inst.stmt, inst.type, line, inst.flag
return NewSSAValue(-node.idx)
else
line = something(inst.line, compact.ir.stmts[before.id][:line])
newline = something(newinst.line, compact.ir.stmts[before.id][:line])
node = add_pending!(compact, before.id, attach_after)
node[:inst], node[:type], node[:line], node[:flag] = inst.stmt, inst.type, line, inst.flag
node = inst_from_newinst!(node, newinst, newline, newflag)
os = OldSSAValue(length(compact.ir.stmts) + length(compact.ir.new_nodes) + length(compact.pending_nodes))
push!(compact.ssa_rename, os)
push!(compact.used_ssas, 0)
Expand All @@ -799,21 +828,21 @@ function insert_node!(compact::IncrementalCompact, before, inst::NewInstruction,
pos = before.id
if pos < compact.idx
renamed = compact.ssa_rename[pos]::AnySSAValue
count_added_node!(compact, inst.stmt)
line = something(inst.line, compact.result[renamed.id][:line])
node = add!(compact.new_new_nodes, renamed.id, attach_after)
count_added_node!(compact, newinst.stmt)
newline = something(newinst.line, compact.result[renamed.id][:line])
node = add_inst!(compact.new_new_nodes, renamed.id, attach_after)
node = inst_from_newinst!(node, newinst, newline, newflag)
push!(compact.new_new_used_ssas, 0)
node[:inst], node[:type], node[:line], node[:flag] = inst.stmt, inst.type, line, inst.flag
return NewSSAValue(-node.idx)
else
if pos > length(compact.ir.stmts)
#@assert attach_after
info = compact.pending_nodes.info[pos - length(compact.ir.stmts) - length(compact.ir.new_nodes)]
pos, attach_after = info.pos, info.attach_after
end
line = something(inst.line, compact.ir.stmts[pos][:line])
newline = something(newinst.line, compact.ir.stmts[pos][:line])
node = add_pending!(compact, pos, attach_after)
node[:inst], node[:type], node[:line], node[:flag] = inst.stmt, inst.type, line, inst.flag
node = inst_from_newinst!(node, newinst, newline, newflag)
os = OldSSAValue(length(compact.ir.stmts) + length(compact.ir.new_nodes) + length(compact.pending_nodes))
push!(compact.ssa_rename, os)
push!(compact.used_ssas, 0)
Expand All @@ -822,18 +851,18 @@ function insert_node!(compact::IncrementalCompact, before, inst::NewInstruction,
elseif isa(before, NewSSAValue)
# TODO: This is incorrect and does not maintain ordering among the new nodes
before_entry = compact.new_new_nodes.info[-before.id]
line = something(inst.line, compact.new_new_nodes.stmts[-before.id][:line])
new_entry = add!(compact.new_new_nodes, before_entry.pos, attach_after)
new_entry[:inst], new_entry[:type], new_entry[:line], new_entry[:flag] = inst.stmt, inst.type, line, inst.flag
newline = something(newinst.line, compact.new_new_nodes.stmts[-before.id][:line])
new_entry = add_inst!(compact.new_new_nodes, before_entry.pos, attach_after)
new_entry = inst_from_newinst!(new_entry, newinst, newline, newflag)
push!(compact.new_new_used_ssas, 0)
return NewSSAValue(-new_entry.idx)
else
error("Unsupported")
end
end

function insert_node_here!(compact::IncrementalCompact, inst::NewInstruction, reverse_affinity::Bool=false)
@assert inst.line !== nothing
function insert_node_here!(compact::IncrementalCompact, newinst::NewInstruction, reverse_affinity::Bool=false)
newline = newinst.line::Int32
refinish = false
result_idx = compact.result_idx
if reverse_affinity &&
Expand All @@ -846,21 +875,9 @@ function insert_node_here!(compact::IncrementalCompact, inst::NewInstruction, re
@assert result_idx == length(compact.result) + 1
resize!(compact, result_idx)
end
flag = inst.flag
if !inst.effect_free_computed
(consistent, effect_free_and_nothrow, nothrow) = stmt_effect_flags(fallback_lattice, inst.stmt, inst.type, compact)
if consistent
flag |= IR_FLAG_CONSISTENT
end
if effect_free_and_nothrow
flag |= IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW
elseif nothrow
flag |= IR_FLAG_NOTHROW
end
end
node = compact.result[result_idx]
node[:inst], node[:type], node[:line], node[:flag] = inst.stmt, inst.type, inst.line, flag
count_added_node!(compact, inst.stmt) && push!(compact.late_fixup, result_idx)
newflag = recompute_inst_flag(newinst, compact)
node = inst_from_newinst!(compact.result[result_idx], newinst, newline, newflag)
count_added_node!(compact, newinst.stmt) && push!(compact.late_fixup, result_idx)
compact.result_idx = result_idx + 1
inst = SSAValue(result_idx)
refinish && finish_current_bb!(compact, 0)
Expand Down Expand Up @@ -1568,7 +1585,6 @@ function fixup_phinode_values!(compact::IncrementalCompact, old_values::Vector{A
return (values, fixup)
end


function fixup_node(compact::IncrementalCompact, @nospecialize(stmt), reify_new_nodes::Bool)
if isa(stmt, PhiNode)
(node, needs_fixup) = fixup_phinode_values!(compact, stmt.values, reify_new_nodes)
Expand Down Expand Up @@ -1721,10 +1737,10 @@ abstract type Inserter; end
struct InsertHere <: Inserter
compact::IncrementalCompact
end
(i::InsertHere)(new_inst::NewInstruction) = insert_node_here!(i.compact, new_inst)
(i::InsertHere)(newinst::NewInstruction) = insert_node_here!(i.compact, newinst)

struct InsertBefore{T<:Union{IRCode, IncrementalCompact}} <: Inserter
src::T
pos::SSAValue
end
(i::InsertBefore)(new_inst::NewInstruction) = insert_node!(i.src, i.pos, new_inst)
(i::InsertBefore)(newinst::NewInstruction) = insert_node!(i.src, i.pos, newinst)
6 changes: 4 additions & 2 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1090,7 +1090,9 @@ function try_inline_finalizer!(ir::IRCode, argexprs::Vector{Any}, idx::Int,
ssa_rename[ssa.id]
end
stmt′ = ssa_substitute_op!(InsertBefore(ir, SSAValue(idx)), inst, stmt′, argexprs, mi.specTypes, mi.sparam_vals, sp_ssa, :default)
ssa_rename[idx′] = insert_node!(ir, idx, NewInstruction(stmt′, inst; line = inst[:line] + linetable_offset), attach_after)
ssa_rename[idx′] = insert_node!(ir, idx,
NewInstruction(inst; stmt=stmt′, line=inst[:line]+linetable_offset),
attach_after)
end

return true
Expand Down Expand Up @@ -1459,7 +1461,7 @@ function canonicalize_typeassert!(compact::IncrementalCompact, idx::Int, stmt::E
NewInstruction(
PiNode(stmt.args[2], compact.result[idx][:type]),
compact.result[idx][:type],
compact.result[idx][:line]), true)
compact.result[idx][:line]), #=reverse_affinity=#true)
compact.ssa_rename[compact.idx-1] = pi
end

Expand Down

0 comments on commit 8d783ef

Please sign in to comment.