diff --git a/base/compiler/ssair/passes.jl b/base/compiler/ssair/passes.jl index 4bfb5f3fcde56..44409cfbcd486 100644 --- a/base/compiler/ssair/passes.jl +++ b/base/compiler/ssair/passes.jl @@ -176,11 +176,12 @@ function find_def_for_use( return def, useblock, curblock end -function collect_leaves(compact::IncrementalCompact, @nospecialize(val), @nospecialize(typeconstraint), 𝕃ₒ::AbstractLattice) +function collect_leaves(compact::IncrementalCompact, @nospecialize(val), @nospecialize(typeconstraint), 𝕃ₒ::AbstractLattice, + predecessors = ((@nospecialize(def), compact::IncrementalCompact) -> isa(def, PhiNode) ? def.values : nothing)) if isa(val, Union{OldSSAValue, SSAValue}) val, typeconstraint = simple_walk_constraint(compact, val, typeconstraint) end - return walk_to_defs(compact, val, typeconstraint, 𝕃ₒ) + return walk_to_defs(compact, val, typeconstraint, predecessors, 𝕃ₒ) end function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#), @@ -235,16 +236,21 @@ function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defss end """ - walk_to_defs(compact, val, typeconstraint) + walk_to_defs(compact, val, typeconstraint, predecessors) Starting at `val` walk use-def chains to get all the leaves feeding into this `val` -(pruning those leaves rules out by path conditions). +(pruning those leaves ruled out by path conditions). + +`predecessors(def, compact)` is a callback which should return the set of possible +predecessors for a "phi-like" node (PhiNode or Core.ifelse) or `nothing` otherwise. """ -function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospecialize(typeconstraint), 𝕃ₒ::AbstractLattice) - visited_phinodes = AnySSAValue[] - isa(defssa, AnySSAValue) || return Any[defssa], visited_phinodes +function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospecialize(typeconstraint), predecessors, 𝕃ₒ::AbstractLattice) + visited_philikes = AnySSAValue[] + isa(defssa, AnySSAValue) || return Any[defssa], visited_philikes def = compact[defssa][:inst] - isa(def, PhiNode) || return Any[defssa], visited_phinodes + if predecessors(def, compact) === nothing + return Any[defssa], visited_philikes + end visited_constraints = IdDict{AnySSAValue, Any}() worklist_defs = AnySSAValue[] worklist_constraints = Any[] @@ -256,12 +262,14 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe typeconstraint = pop!(worklist_constraints) visited_constraints[defssa] = typeconstraint def = compact[defssa][:inst] - if isa(def, PhiNode) - push!(visited_phinodes, defssa) + values = predecessors(def, compact) + if values !== nothing + push!(visited_philikes, defssa) possible_predecessors = Int[] - for n in 1:length(def.edges) - isassigned(def.values, n) || continue - val = def.values[n] + + for n in 1:length(values) + isassigned(values, n) || continue + val = values[n] if is_old(compact, defssa) && isa(val, SSAValue) val = OldSSAValue(val.id) end @@ -270,8 +278,7 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe push!(possible_predecessors, n) end for n in possible_predecessors - pred = def.edges[n] - val = def.values[n] + val = values[n] if is_old(compact, defssa) && isa(val, SSAValue) val = OldSSAValue(val.id) end @@ -306,7 +313,7 @@ function walk_to_defs(compact::IncrementalCompact, @nospecialize(defssa), @nospe push!(leaves, defssa) end end - return leaves, visited_phinodes + return leaves, visited_philikes end function record_immutable_preserve!(new_preserves::Vector{Any}, def::Expr, compact::IncrementalCompact) @@ -566,7 +573,13 @@ function lift_comparison_leaves!(@specialize(tfunc), val, typeconstraint = simple_walk_constraint(compact, val, typeconstraint) end isa(typeconstraint, Union) || return # bail out if there won't be a good chance for lifting - leaves, visited_phinodes = collect_leaves(compact, val, typeconstraint, 𝕃ₒ) + + predecessors = function (@nospecialize(def), compact::IncrementalCompact) + isa(def, PhiNode) && return def.values + is_known_call(def, Core.ifelse, compact) && return def.args[3:4] + return nothing + end + leaves, visited_philikes = collect_leaves(compact, val, typeconstraint, 𝕃ₒ, predecessors) length(leaves) ≤ 1 && return # bail out if we don't have multiple leaves # check if we can evaluate the comparison for each one of the leaves @@ -586,18 +599,51 @@ function lift_comparison_leaves!(@specialize(tfunc), # perform lifting lifted_val = perform_lifting!(compact, - visited_phinodes, cmp, lifting_cache, Bool, + visited_philikes, cmp, lifting_cache, Bool, lifted_leaves::LiftedLeaves, val, nothing)::LiftedValue compact[idx] = lifted_val.val end -struct LiftedPhi +struct IfElseCall + call::Expr +end + +# An intermediate data structure used for lifting expressions through a +# "phi-like" instruction (either a PhiNode or a call to Core.ifelse) +struct LiftedPhilike ssa::AnySSAValue - node::PhiNode + node::Union{PhiNode,IfElseCall} need_argupdate::Bool end +struct SkipToken end; const SKIP_TOKEN = SkipToken() + +function lifted_value(compact::IncrementalCompact, @nospecialize(old_node_ssa#=::AnySSAValue=#), @nospecialize(old_value), + lifted_philikes::Vector{LiftedPhilike}, lifted_leaves::LiftedLeaves, reverse_mapping::IdDict{AnySSAValue, Int}) + val = old_value + if is_old(compact, old_node_ssa) && isa(val, SSAValue) + val = OldSSAValue(val.id) + end + if isa(val, AnySSAValue) + val = simple_walk(compact, val) + end + if val in keys(lifted_leaves) + lifted_val = lifted_leaves[val] + lifted_val === nothing && return UNDEF_TOKEN + val = lifted_val.val + if isa(val, AnySSAValue) + callback = (@nospecialize(pi), @nospecialize(idx)) -> true + val = simple_walk(compact, val, callback) + end + return val + elseif isa(val, AnySSAValue) && val in keys(reverse_mapping) + return lifted_philikes[reverse_mapping[val]].ssa + else + return SKIP_TOKEN # Probably ignored by path condition, skip this + end +end + function is_old(compact, @nospecialize(old_node_ssa)) isa(old_node_ssa, OldSSAValue) && !is_pending(compact, old_node_ssa) && @@ -605,13 +651,13 @@ function is_old(compact, @nospecialize(old_node_ssa)) end function perform_lifting!(compact::IncrementalCompact, - visited_phinodes::Vector{AnySSAValue}, @nospecialize(cache_key), + visited_philikes::Vector{AnySSAValue}, @nospecialize(cache_key), lifting_cache::IdDict{Pair{AnySSAValue, Any}, AnySSAValue}, @nospecialize(result_t), lifted_leaves::LiftedLeaves, @nospecialize(stmt_val), lazydomtree::Union{LazyDomtree,Nothing}) reverse_mapping = IdDict{AnySSAValue, Int}() - for id in 1:length(visited_phinodes) - reverse_mapping[visited_phinodes[id]] = id + for id in 1:length(visited_philikes) + reverse_mapping[visited_philikes[id]] = id end # Check if all the lifted leaves are the same @@ -636,7 +682,7 @@ function perform_lifting!(compact::IncrementalCompact, dominates_all = true if lazydomtree !== nothing domtree = get!(lazydomtree) - for item in visited_phinodes + for item in visited_philikes if !dominates_ssa(compact, domtree, the_leaf_val, item) dominates_all = false break @@ -649,64 +695,82 @@ function perform_lifting!(compact::IncrementalCompact, end # Insert PhiNodes - nphis = length(visited_phinodes) - lifted_phis = Vector{LiftedPhi}(undef, nphis) - for i = 1:nphis - item = visited_phinodes[i] + nphilikes = length(visited_philikes) + lifted_philikes = Vector{LiftedPhilike}(undef, nphilikes) + for i = 1:nphilikes + old_ssa = visited_philikes[i] + old_inst = compact[old_ssa] + old_node = old_inst[:inst]::Union{PhiNode,Expr} # FIXME this cache is broken somehow - # ckey = Pair{AnySSAValue, Any}(item, cache_key) + # ckey = Pair{AnySSAValue, Any}(old_ssa, cache_key) # cached = ckey in keys(lifting_cache) cached = false if cached ssa = lifting_cache[ckey] - lifted_phis[i] = LiftedPhi(ssa, compact[ssa][:inst]::PhiNode, false) + if isa(old_node, PhiNode) + lifted_philikes[i] = LiftedPhilike(ssa, old_node, false) + else + lifted_philikes[i] = LiftedPhilike(ssa, IfElseCall(old_node), false) + end continue end - n = PhiNode() - ssa = insert_node!(compact, item, effect_free(NewInstruction(n, result_t))) + if isa(old_node, PhiNode) + new_node = PhiNode() + ssa = insert_node!(compact, old_ssa, effect_free(NewInstruction(new_node, result_t))) + lifted_philikes[i] = LiftedPhilike(ssa, new_node, true) + else + @assert is_known_call(old_node, Core.ifelse, compact) + ifelse_func, condition, then_result, else_result = old_node.args + if is_old(compact, old_ssa) && isa(condition, SSAValue) + condition = OldSSAValue(condition.id) + end + + new_node = Expr(:call, ifelse_func, condition, then_result, else_result) + new_inst = NewInstruction(new_node, result_t, NoCallInfo(), old_inst[:line], old_inst[:flag]) + + ssa = insert_node!(compact, old_ssa, new_inst) + lifted_philikes[i] = LiftedPhilike(ssa, IfElseCall(new_node), true) + end # lifting_cache[ckey] = ssa - lifted_phis[i] = LiftedPhi(ssa, n, true) end # Fix up arguments - for i = 1:nphis - (old_node_ssa, lf) = visited_phinodes[i], lifted_phis[i] - old_node = compact[old_node_ssa][:inst]::PhiNode - new_node = lf.node - should_count = !isa(lf.ssa, OldSSAValue) || already_inserted(compact, lf.ssa) + for i = 1:nphilikes + (old_node_ssa, lf) = visited_philikes[i], lifted_philikes[i] lf.need_argupdate || continue - for i = 1:length(old_node.edges) - edge = old_node.edges[i] - isassigned(old_node.values, i) || continue - val = old_node.values[i] - if is_old(compact, old_node_ssa) && isa(val, SSAValue) - val = OldSSAValue(val.id) - end - if isa(val, AnySSAValue) - val = simple_walk(compact, val) - end - if val in keys(lifted_leaves) - push!(new_node.edges, edge) - lifted_val = lifted_leaves[val] - if lifted_val === nothing + should_count = !isa(lf.ssa, OldSSAValue) || already_inserted(compact, lf.ssa) + + lfnode = lf.node + if isa(lfnode, PhiNode) + old_node = compact[old_node_ssa][:inst]::PhiNode + new_node = lfnode + for i = 1:length(old_node.values) + isassigned(old_node.values, i) || continue + val = lifted_value(compact, old_node_ssa, old_node.values[i], + lifted_philikes, lifted_leaves, reverse_mapping) + val !== SKIP_TOKEN && push!(new_node.edges, old_node.edges[i]) + if val === UNDEF_TOKEN resize!(new_node.values, length(new_node.values)+1) - continue - end - val = lifted_val.val - if isa(val, AnySSAValue) - callback = (@nospecialize(pi), @nospecialize(idx)) -> true - val = simple_walk(compact, val, callback) + elseif val !== SKIP_TOKEN + should_count && _count_added_node!(compact, val) + push!(new_node.values, val) end - should_count && _count_added_node!(compact, val) - push!(new_node.values, val) - elseif isa(val, AnySSAValue) && val in keys(reverse_mapping) - push!(new_node.edges, edge) - newval = lifted_phis[reverse_mapping[val]].ssa - should_count && _count_added_node!(compact, newval) - push!(new_node.values, newval) - else - # Probably ignored by path condition, skip this end + elseif isa(lfnode, IfElseCall) + then_result, else_result = lfnode.call.args[3], lfnode.call.args[4] + + then_result = lifted_value(compact, old_node_ssa, then_result, + lifted_philikes, lifted_leaves, reverse_mapping) + else_result = lifted_value(compact, old_node_ssa, else_result, + lifted_philikes, lifted_leaves, reverse_mapping) + + should_count && _count_added_node!(compact, then_result) + should_count && _count_added_node!(compact, else_result) + + @assert then_result !== SKIP_TOKEN && then_result !== UNDEF_TOKEN + @assert else_result !== SKIP_TOKEN && else_result !== UNDEF_TOKEN + + lfnode.call.args[3], lfnode.call.args[4] = then_result, else_result end end @@ -718,7 +782,7 @@ function perform_lifting!(compact::IncrementalCompact, if stmt_val in keys(lifted_leaves) return lifted_leaves[stmt_val] elseif isa(stmt_val, AnySSAValue) && stmt_val in keys(reverse_mapping) - return LiftedValue(lifted_phis[reverse_mapping[stmt_val]].ssa) + return LiftedValue(lifted_philikes[reverse_mapping[stmt_val]].ssa) end return stmt_val # N.B. should never happen @@ -1006,7 +1070,7 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing) field = try_compute_fieldidx_stmt(compact, stmt, struct_typ) field === nothing && continue - leaves, visited_phinodes = collect_leaves(compact, val, struct_typ, 𝕃ₒ) + leaves, visited_philikes = collect_leaves(compact, val, struct_typ, 𝕃ₒ) isempty(leaves) && continue result_t = argextype(SSAValue(idx), compact) @@ -1019,7 +1083,7 @@ function sroa_pass!(ir::IRCode, inlining::Union{Nothing,InliningState}=nothing) end val = perform_lifting!(compact, - visited_phinodes, field, lifting_cache, result_t, lifted_leaves, val, lazydomtree) + visited_philikes, field, lifting_cache, result_t, lifted_leaves, val, lazydomtree) # Insert the undef check if necessary if any_undef && val === nothing diff --git a/test/compiler/irpasses.jl b/test/compiler/irpasses.jl index c704a8cf1c434..f3c74df884cad 100644 --- a/test/compiler/irpasses.jl +++ b/test/compiler/irpasses.jl @@ -537,7 +537,7 @@ end # comparison lifting # ================== -let # lifting `===` +let # lifting `===` through PhiNode src = code_typed1((Bool,Int,)) do c, x y = c ? x : nothing y === nothing # => ϕ(false, true) @@ -557,7 +557,15 @@ let # lifting `===` end end -let # lifting `isa` +let # lifting `===` through Core.ifelse + src = code_typed1((Bool,Int,)) do c, x + y = Core.ifelse(c, x, nothing) + y === nothing # => Core.ifelse(c, false, true) + end + @test count(iscall((src, ===)), src.code) == 0 +end + +let # lifting `isa` through PhiNode src = code_typed1((Bool,Int,)) do c, x y = c ? x : nothing isa(y, Int) # => ϕ(true, false) @@ -580,7 +588,16 @@ let # lifting `isa` end end -let # lifting `isdefined` +let # lifting `isa` through Core.ifelse + src = code_typed1((Bool,Int,)) do c, x + y = Core.ifelse(c, x, nothing) + isa(y, Int) # => Core.ifelse(c, true, false) + end + @test count(iscall((src, isa)), src.code) == 0 +end + + +let # lifting `isdefined` through PhiNode src = code_typed1((Bool,Some{Int},)) do c, x y = c ? x : nothing isdefined(y, 1) # => ϕ(true, false) @@ -603,6 +620,14 @@ let # lifting `isdefined` end end +let # lifting `isdefined` through Core.ifelse + src = code_typed1((Bool,Some{Int},)) do c, x + y = Core.ifelse(c, x, nothing) + isdefined(y, 1) # => Core.ifelse(c, true, false) + end + @test count(iscall((src, isdefined)), src.code) == 0 +end + mutable struct Foo30594; x::Float64; end Base.copy(x::Foo30594) = Foo30594(x.x) function add!(p::Foo30594, off::Foo30594)