Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend SROA comparison lifting to Core.ifelse #49882

Merged
merged 1 commit into from
May 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
204 changes: 134 additions & 70 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,12 @@ function find_def_for_use(
return def, useblock, curblock
end

function collect_leaves(compact::IncrementalCompact, @nospecialize(val), @nospecialize(typeconstraint), 𝕃ₒ::AbstractLattice)
function collect_leaves(compact::IncrementalCompact, @nospecialize(val), @nospecialize(typeconstraint), 𝕃ₒ::AbstractLattice,
predecessors = ((@nospecialize(def), compact::IncrementalCompact) -> isa(def, PhiNode) ? def.values : nothing))
if isa(val, Union{OldSSAValue, SSAValue})
val, typeconstraint = simple_walk_constraint(compact, val, typeconstraint)
end
return walk_to_defs(compact, val, typeconstraint, 𝕃ₒ)
return walk_to_defs(compact, val, typeconstraint, predecessors, 𝕃ₒ)
end

function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
Expand Down Expand Up @@ -235,16 +236,21 @@ function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defss
end

"""
walk_to_defs(compact, val, typeconstraint)
walk_to_defs(compact, val, typeconstraint, predecessors)

Starting at `val` walk use-def chains to get all the leaves feeding into this `val`
(pruning those leaves rules out by path conditions).
(pruning those leaves ruled out by path conditions).

`predecessors(def, compact)` is a callback which should return the set of possible
predecessors for a "phi-like" node (PhiNode or Core.ifelse) or `nothing` otherwise.
"""
function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospecialize(typeconstraint), 𝕃ₒ::AbstractLattice)
visited_phinodes = AnySSAValue[]
isa(defssa, AnySSAValue) || return Any[defssa], visited_phinodes
function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospecialize(typeconstraint), predecessors, 𝕃ₒ::AbstractLattice)
visited_philikes = AnySSAValue[]
isa(defssa, AnySSAValue) || return Any[defssa], visited_philikes
def = compact[defssa][:inst]
isa(def, PhiNode) || return Any[defssa], visited_phinodes
if predecessors(def, compact) === nothing
return Any[defssa], visited_philikes
end
visited_constraints = IdDict{AnySSAValue, Any}()
worklist_defs = AnySSAValue[]
worklist_constraints = Any[]
Expand All @@ -256,12 +262,14 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe
typeconstraint = pop!(worklist_constraints)
visited_constraints[defssa] = typeconstraint
def = compact[defssa][:inst]
if isa(def, PhiNode)
push!(visited_phinodes, defssa)
values = predecessors(def, compact)
if values !== nothing
push!(visited_philikes, defssa)
possible_predecessors = Int[]
for n in 1:length(def.edges)
isassigned(def.values, n) || continue
val = def.values[n]

for n in 1:length(values)
isassigned(values, n) || continue
val = values[n]
if is_old(compact, defssa) && isa(val, SSAValue)
val = OldSSAValue(val.id)
end
Expand All @@ -270,8 +278,7 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe
push!(possible_predecessors, n)
end
for n in possible_predecessors
pred = def.edges[n]
val = def.values[n]
val = values[n]
if is_old(compact, defssa) && isa(val, SSAValue)
val = OldSSAValue(val.id)
end
Expand Down Expand Up @@ -306,7 +313,7 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe
push!(leaves, defssa)
end
end
return leaves, visited_phinodes
return leaves, visited_philikes
end

function record_immutable_preserve!(new_preserves::Vector{Any}, def::Expr, compact::IncrementalCompact)
Expand Down Expand Up @@ -566,7 +573,13 @@ function lift_comparison_leaves!(@specialize(tfunc),
val, typeconstraint = simple_walk_constraint(compact, val, typeconstraint)
end
isa(typeconstraint, Union) || return # bail out if there won't be a good chance for lifting
leaves, visited_phinodes = collect_leaves(compact, val, typeconstraint, 𝕃ₒ)

predecessors = function (@nospecialize(def), compact::IncrementalCompact)
isa(def, PhiNode) && return def.values
is_known_call(def, Core.ifelse, compact) && return def.args[3:4]
return nothing
end
leaves, visited_philikes = collect_leaves(compact, val, typeconstraint, 𝕃ₒ, predecessors)
length(leaves) ≤ 1 && return # bail out if we don't have multiple leaves

# check if we can evaluate the comparison for each one of the leaves
Expand All @@ -586,32 +599,65 @@ function lift_comparison_leaves!(@specialize(tfunc),

# perform lifting
lifted_val = perform_lifting!(compact,
visited_phinodes, cmp, lifting_cache, Bool,
visited_philikes, cmp, lifting_cache, Bool,
lifted_leaves::LiftedLeaves, val, nothing)::LiftedValue

compact[idx] = lifted_val.val
end

struct LiftedPhi
struct IfElseCall
call::Expr
end

# An intermediate data structure used for lifting expressions through a
# "phi-like" instruction (either a PhiNode or a call to Core.ifelse)
struct LiftedPhilike
ssa::AnySSAValue
node::PhiNode
node::Union{PhiNode,IfElseCall}
need_argupdate::Bool
end

struct SkipToken end; const SKIP_TOKEN = SkipToken()

function lifted_value(compact::IncrementalCompact, @nospecialize(old_node_ssa#=::AnySSAValue=#), @nospecialize(old_value),
lifted_philikes::Vector{LiftedPhilike}, lifted_leaves::LiftedLeaves, reverse_mapping::IdDict{AnySSAValue, Int})
val = old_value
if is_old(compact, old_node_ssa) && isa(val, SSAValue)
val = OldSSAValue(val.id)
end
if isa(val, AnySSAValue)
val = simple_walk(compact, val)
end
if val in keys(lifted_leaves)
lifted_val = lifted_leaves[val]
lifted_val === nothing && return UNDEF_TOKEN
val = lifted_val.val
if isa(val, AnySSAValue)
callback = (@nospecialize(pi), @nospecialize(idx)) -> true
val = simple_walk(compact, val, callback)
end
return val
elseif isa(val, AnySSAValue) && val in keys(reverse_mapping)
return lifted_philikes[reverse_mapping[val]].ssa
else
return SKIP_TOKEN # Probably ignored by path condition, skip this
end
end

function is_old(compact, @nospecialize(old_node_ssa))
isa(old_node_ssa, OldSSAValue) &&
!is_pending(compact, old_node_ssa) &&
!already_inserted(compact, old_node_ssa)
end

function perform_lifting!(compact::IncrementalCompact,
visited_phinodes::Vector{AnySSAValue}, @nospecialize(cache_key),
visited_philikes::Vector{AnySSAValue}, @nospecialize(cache_key),
lifting_cache::IdDict{Pair{AnySSAValue, Any}, AnySSAValue},
@nospecialize(result_t), lifted_leaves::LiftedLeaves, @nospecialize(stmt_val),
lazydomtree::Union{LazyDomtree,Nothing})
reverse_mapping = IdDict{AnySSAValue, Int}()
for id in 1:length(visited_phinodes)
reverse_mapping[visited_phinodes[id]] = id
for id in 1:length(visited_philikes)
reverse_mapping[visited_philikes[id]] = id
end

# Check if all the lifted leaves are the same
Expand All @@ -636,7 +682,7 @@ function perform_lifting!(compact::IncrementalCompact,
dominates_all = true
if lazydomtree !== nothing
domtree = get!(lazydomtree)
for item in visited_phinodes
for item in visited_philikes
if !dominates_ssa(compact, domtree, the_leaf_val, item)
dominates_all = false
break
Expand All @@ -649,64 +695,82 @@ function perform_lifting!(compact::IncrementalCompact,
end

# Insert PhiNodes
nphis = length(visited_phinodes)
lifted_phis = Vector{LiftedPhi}(undef, nphis)
for i = 1:nphis
item = visited_phinodes[i]
nphilikes = length(visited_philikes)
lifted_philikes = Vector{LiftedPhilike}(undef, nphilikes)
for i = 1:nphilikes
old_ssa = visited_philikes[i]
old_inst = compact[old_ssa]
old_node = old_inst[:inst]::Union{PhiNode,Expr}
# FIXME this cache is broken somehow
# ckey = Pair{AnySSAValue, Any}(item, cache_key)
# ckey = Pair{AnySSAValue, Any}(old_ssa, cache_key)
# cached = ckey in keys(lifting_cache)
cached = false
if cached
ssa = lifting_cache[ckey]
lifted_phis[i] = LiftedPhi(ssa, compact[ssa][:inst]::PhiNode, false)
if isa(old_node, PhiNode)
lifted_philikes[i] = LiftedPhilike(ssa, old_node, false)
else
lifted_philikes[i] = LiftedPhilike(ssa, IfElseCall(old_node), false)
end
continue
end
n = PhiNode()
ssa = insert_node!(compact, item, effect_free(NewInstruction(n, result_t)))
if isa(old_node, PhiNode)
new_node = PhiNode()
ssa = insert_node!(compact, old_ssa, effect_free(NewInstruction(new_node, result_t)))
lifted_philikes[i] = LiftedPhilike(ssa, new_node, true)
else
@assert is_known_call(old_node, Core.ifelse, compact)
ifelse_func, condition, then_result, else_result = old_node.args
if is_old(compact, old_ssa) && isa(condition, SSAValue)
condition = OldSSAValue(condition.id)
end

new_node = Expr(:call, ifelse_func, condition, then_result, else_result)
new_inst = NewInstruction(new_node, result_t, NoCallInfo(), old_inst[:line], old_inst[:flag])

ssa = insert_node!(compact, old_ssa, new_inst)
lifted_philikes[i] = LiftedPhilike(ssa, IfElseCall(new_node), true)
end
# lifting_cache[ckey] = ssa
lifted_phis[i] = LiftedPhi(ssa, n, true)
end

# Fix up arguments
for i = 1:nphis
(old_node_ssa, lf) = visited_phinodes[i], lifted_phis[i]
old_node = compact[old_node_ssa][:inst]::PhiNode
new_node = lf.node
should_count = !isa(lf.ssa, OldSSAValue) || already_inserted(compact, lf.ssa)
for i = 1:nphilikes
(old_node_ssa, lf) = visited_philikes[i], lifted_philikes[i]
lf.need_argupdate || continue
for i = 1:length(old_node.edges)
edge = old_node.edges[i]
isassigned(old_node.values, i) || continue
val = old_node.values[i]
if is_old(compact, old_node_ssa) && isa(val, SSAValue)
val = OldSSAValue(val.id)
end
if isa(val, AnySSAValue)
val = simple_walk(compact, val)
end
if val in keys(lifted_leaves)
push!(new_node.edges, edge)
lifted_val = lifted_leaves[val]
if lifted_val === nothing
should_count = !isa(lf.ssa, OldSSAValue) || already_inserted(compact, lf.ssa)

lfnode = lf.node
if isa(lfnode, PhiNode)
old_node = compact[old_node_ssa][:inst]::PhiNode
new_node = lfnode
for i = 1:length(old_node.values)
isassigned(old_node.values, i) || continue
val = lifted_value(compact, old_node_ssa, old_node.values[i],
lifted_philikes, lifted_leaves, reverse_mapping)
val !== SKIP_TOKEN && push!(new_node.edges, old_node.edges[i])
if val === UNDEF_TOKEN
resize!(new_node.values, length(new_node.values)+1)
continue
end
val = lifted_val.val
if isa(val, AnySSAValue)
callback = (@nospecialize(pi), @nospecialize(idx)) -> true
val = simple_walk(compact, val, callback)
elseif val !== SKIP_TOKEN
should_count && _count_added_node!(compact, val)
push!(new_node.values, val)
end
should_count && _count_added_node!(compact, val)
push!(new_node.values, val)
elseif isa(val, AnySSAValue) && val in keys(reverse_mapping)
push!(new_node.edges, edge)
newval = lifted_phis[reverse_mapping[val]].ssa
should_count && _count_added_node!(compact, newval)
push!(new_node.values, newval)
else
# Probably ignored by path condition, skip this
end
elseif isa(lfnode, IfElseCall)
then_result, else_result = lfnode.call.args[3], lfnode.call.args[4]

then_result = lifted_value(compact, old_node_ssa, then_result,
lifted_philikes, lifted_leaves, reverse_mapping)
else_result = lifted_value(compact, old_node_ssa, else_result,
lifted_philikes, lifted_leaves, reverse_mapping)

should_count && _count_added_node!(compact, then_result)
should_count && _count_added_node!(compact, else_result)

@assert then_result !== SKIP_TOKEN && then_result !== UNDEF_TOKEN
@assert else_result !== SKIP_TOKEN && else_result !== UNDEF_TOKEN
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't convinced myself yet that these asserts will always be true

Copy link
Sponsor Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As far as I understand, we should always have then_result and else_result within the keys of lifted_leaves, so lifted_value should never return SKIP_TOKEN here. UNDEF_TOKEN is only used for :new expression, so it does not matter.


lfnode.call.args[3], lfnode.call.args[4] = then_result, else_result
end
end

Expand All @@ -718,7 +782,7 @@ function perform_lifting!(compact::IncrementalCompact,
if stmt_val in keys(lifted_leaves)
return lifted_leaves[stmt_val]
elseif isa(stmt_val, AnySSAValue) && stmt_val in keys(reverse_mapping)
return LiftedValue(lifted_phis[reverse_mapping[stmt_val]].ssa)
return LiftedValue(lifted_philikes[reverse_mapping[stmt_val]].ssa)
end

return stmt_val # N.B. should never happen
Expand Down Expand Up @@ -1006,7 +1070,7 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
field = try_compute_fieldidx_stmt(compact, stmt, struct_typ)
field === nothing && continue

leaves, visited_phinodes = collect_leaves(compact, val, struct_typ, 𝕃ₒ)
leaves, visited_philikes = collect_leaves(compact, val, struct_typ, 𝕃ₒ)
isempty(leaves) && continue

result_t = argextype(SSAValue(idx), compact)
Expand All @@ -1019,7 +1083,7 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing)
end

val = perform_lifting!(compact,
visited_phinodes, field, lifting_cache, result_t, lifted_leaves, val, lazydomtree)
visited_philikes, field, lifting_cache, result_t, lifted_leaves, val, lazydomtree)

# Insert the undef check if necessary
if any_undef && val === nothing
Expand Down
31 changes: 28 additions & 3 deletions test/compiler/irpasses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ end
# comparison lifting
# ==================

let # lifting `===`
let # lifting `===` through PhiNode
src = code_typed1((Bool,Int,)) do c, x
y = c ? x : nothing
y === nothing # => ϕ(false, true)
Expand All @@ -557,7 +557,15 @@ let # lifting `===`
end
end

let # lifting `isa`
let # lifting `===` through Core.ifelse
src = code_typed1((Bool,Int,)) do c, x
y = Core.ifelse(c, x, nothing)
y === nothing # => Core.ifelse(c, false, true)
end
@test count(iscall((src, ===)), src.code) == 0
end

let # lifting `isa` through PhiNode
src = code_typed1((Bool,Int,)) do c, x
y = c ? x : nothing
isa(y, Int) # => ϕ(true, false)
Expand All @@ -580,7 +588,16 @@ let # lifting `isa`
end
end

let # lifting `isdefined`
let # lifting `isa` through Core.ifelse
src = code_typed1((Bool,Int,)) do c, x
y = Core.ifelse(c, x, nothing)
isa(y, Int) # => Core.ifelse(c, true, false)
end
@test count(iscall((src, isa)), src.code) == 0
end


let # lifting `isdefined` through PhiNode
src = code_typed1((Bool,Some{Int},)) do c, x
y = c ? x : nothing
isdefined(y, 1) # => ϕ(true, false)
Expand All @@ -603,6 +620,14 @@ let # lifting `isdefined`
end
end

let # lifting `isdefined` through Core.ifelse
src = code_typed1((Bool,Some{Int},)) do c, x
y = Core.ifelse(c, x, nothing)
isdefined(y, 1) # => Core.ifelse(c, true, false)
end
@test count(iscall((src, isdefined)), src.code) == 0
end

mutable struct Foo30594; x::Float64; end
Base.copy(x::Foo30594) = Foo30594(x.x)
function add!(p::Foo30594, off::Foo30594)
Expand Down