Skip to content

Commit

Permalink
Merge pull request #21 from N5N3/N5N3-patch-3-1
Browse files Browse the repository at this point in the history
N5 n3 patch 3 1
  • Loading branch information
N5N3 authored Sep 30, 2021
2 parents 8d28785 + 8e80036 commit d2914a9
Show file tree
Hide file tree
Showing 19 changed files with 255 additions and 143 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ Standard library changes
constructing the range. ([#40382])
* TCP socket objects now expose `closewrite` functionality and support half-open mode usage ([#40783]).
* Intersect returns a result with the eltype of the type-promoted eltypes of the two inputs ([#41769]).
* `Iterators.countfrom` now accepts any type that defines `+`. ([#37747])

#### InteractiveUtils
* A new macro `@time_imports` for reporting any time spent importing packages and their dependencies ([#41612])
Expand Down
17 changes: 10 additions & 7 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1087,10 +1087,12 @@ broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::LinRange) = LinRange(-r.star
# For #18336 we need to prevent promotion of the step type:
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractRange, x::Number) = range(first(r) + x, step=step(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::AbstractRange) = range(x + first(r), step=step(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::OrdinalRange, x::Real) = range(first(r) + x, last(r) + x, step=step(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Real, r::Real) = range(x + first(r), x + last(r), step=step(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractUnitRange, x::Real) = range(first(r) + x, last(r) + x)
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Real, r::AbstractUnitRange) = range(x + first(r), x + last(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::OrdinalRange, x::Integer) = range(first(r) + x, last(r) + x, step=step(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Integer, r::OrdinalRange) = range(x + first(r), x + last(r), step=step(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractUnitRange, x::Integer) = range(first(r) + x, last(r) + x)
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Integer, r::AbstractUnitRange) = range(x + first(r), x + last(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::AbstractUnitRange, x::Real) = range(first(r) + x, length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Real, r::AbstractUnitRange) = range(x + first(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r::StepRangeLen{T}, x::Number) where T =
StepRangeLen{typeof(T(r.ref)+x)}(r.ref + x, r.step, length(r), r.offset)
broadcasted(::DefaultArrayStyle{1}, ::typeof(+), x::Number, r::StepRangeLen{T}) where T =
Expand All @@ -1101,9 +1103,10 @@ broadcasted(::DefaultArrayStyle{1}, ::typeof(+), r1::AbstractRange, r2::Abstract

broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractRange, x::Number) = range(first(r) - x, step=step(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::AbstractRange) = range(x - first(r), step=-step(r), length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::OrdinalRange, x::Real) = range(first(r) - x, last(r) - x, step=step(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Real, r::OrdinalRange) = range(x - first(r), x - last(r), step=-step(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractUnitRange, x::Real) = range(first(r) - x, last(r) - x)
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::OrdinalRange, x::Integer) = range(first(r) - x, last(r) - x, step=step(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Integer, r::OrdinalRange) = range(x - first(r), x - last(r), step=-step(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractUnitRange, x::Integer) = range(first(r) - x, last(r) - x)
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::AbstractUnitRange, x::Real) = range(first(r) - x, length=length(r))
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), r::StepRangeLen{T}, x::Number) where T =
StepRangeLen{typeof(T(r.ref)-x)}(r.ref - x, r.step, length(r), r.offset)
broadcasted(::DefaultArrayStyle{1}, ::typeof(-), x::Number, r::StepRangeLen{T}) where T =
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1316,7 +1316,7 @@ function iterate(compact::IncrementalCompact, (idx, active_bb)::Tuple{Int, Int}=
compact.result[old_result_idx][:inst]), (compact.idx, active_bb)
end

function maybe_erase_unused!(extra_worklist, compact, idx, callback = x->nothing)
function maybe_erase_unused!(extra_worklist::Vector{Int}, compact::IncrementalCompact, idx::Int, callback = x::SSAValue->nothing)
stmt = compact.result[idx][:inst]
stmt === nothing && return false
if compact_exprtype(compact, SSAValue(idx)) === Bottom
Expand Down
90 changes: 59 additions & 31 deletions base/compiler/ssair/passes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,8 @@ function compute_value_for_use(ir::IRCode, domtree::DomTree, allblocks::Vector{I
end
end

function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#), pi_callback=(pi, idx)->false)
function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
callback = (@nospecialize(pi), @nospecialize(idx)) -> false)
while true
if isa(defssa, OldSSAValue)
if already_inserted(compact, defssa)
Expand All @@ -124,7 +125,7 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
end
def = compact[defssa]
if isa(def, PiNode)
if pi_callback(def, defssa)
if callback(def, defssa)
return defssa
end
def = def.val
Expand All @@ -135,7 +136,7 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
end
defssa = def
elseif isa(def, AnySSAValue)
pi_callback(def, defssa)
callback(def, defssa)
if isa(def, SSAValue)
is_old(compact, defssa) && (def = OldSSAValue(def.id))
end
Expand All @@ -148,12 +149,15 @@ function simple_walk(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSA
end
end

function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defidx), @nospecialize(typeconstraint) = types(compact)[defidx])
function simple_walk_constraint(compact::IncrementalCompact, @nospecialize(defssa#=::AnySSAValue=#),
@nospecialize(typeconstraint) = types(compact)[defssa])
callback = function (@nospecialize(pi), @nospecialize(idx))
isa(pi, PiNode) && (typeconstraint = typeintersect(typeconstraint, widenconst(pi.typ)))
if isa(pi, PiNode)
typeconstraint = typeintersect(typeconstraint, widenconst(pi.typ))
end
return false
end
def = simple_walk(compact, defidx, callback)
def = simple_walk(compact, defssa, callback)
return Pair{Any, Any}(def, typeconstraint)
end

Expand Down Expand Up @@ -273,8 +277,10 @@ function is_getfield_captures(@nospecialize(def), compact::IncrementalCompact)
return oc Core.OpaqueClosure
end

function lift_leaves(compact::IncrementalCompact, @nospecialize(stmt),
@nospecialize(result_t), field::Int, leaves::Vector{Any})
# try to compute lifted values that can replace `getfield(x, field)` call
# where `x` is an immutable struct that are defined at any of `leaves`
function lift_leaves(compact::IncrementalCompact,
@nospecialize(result_t), field::Int, leaves::Vector{Any})
# For every leaf, the lifted value
lifted_leaves = IdDict{Any, Any}()
maybe_undef = false
Expand Down Expand Up @@ -396,13 +402,13 @@ function lift_leaves(compact::IncrementalCompact, @nospecialize(stmt),
elseif isa(leaf, Union{Argument, Expr})
return nothing
end
!ismutable(leaf) || return nothing
ismutable(leaf) && return nothing
isdefined(leaf, field) || return nothing
val = getfield(leaf, field)
is_inlineable_constant(val) || return nothing
lifted_leaves[leaf_key] = RefValue{Any}(quoted(val))
end
lifted_leaves, maybe_undef
return lifted_leaves, maybe_undef
end

make_MaybeUndef(@nospecialize(typ)) = isa(typ, MaybeUndef) ? typ : MaybeUndef(typ)
Expand All @@ -415,13 +421,11 @@ function lift_comparison!(compact::IncrementalCompact, idx::Int,
typeconstraint = widenconst(c2)
val = stmt.args[3]
else
cmp = c2
cmp = c2::Const
typeconstraint = widenconst(c1)
val = stmt.args[2]
end

is_type_only = isdefined(typeof(cmp), :instance)

if isa(val, Union{OldSSAValue, SSAValue})
val, typeconstraint = simple_walk_constraint(compact, val, typeconstraint)
end
Expand Down Expand Up @@ -497,7 +501,7 @@ function perform_lifting!(compact::IncrementalCompact,
if is_old(compact, old_node_ssa) && isa(val, SSAValue)
val = OldSSAValue(val.id)
end
if isa(val, Union{NewSSAValue, SSAValue, OldSSAValue})
if isa(val, AnySSAValue)
val = simple_walk(compact, val)
end
if val in keys(lifted_leaves)
Expand All @@ -508,11 +512,12 @@ function perform_lifting!(compact::IncrementalCompact,
continue
end
lifted_val = lifted_val.x
if isa(lifted_val, Union{NewSSAValue, SSAValue, OldSSAValue})
lifted_val = simple_walk(compact, lifted_val, (pi, idx)->true)
if isa(lifted_val, AnySSAValue)
callback = (@nospecialize(pi), @nospecialize(idx)) -> true
lifted_val = simple_walk(compact, lifted_val, callback)
end
push!(new_node.values, lifted_val)
elseif isa(val, Union{NewSSAValue, SSAValue, OldSSAValue}) && val in keys(reverse_mapping)
elseif isa(val, AnySSAValue) && val in keys(reverse_mapping)
push!(new_node.edges, edge)
push!(new_node.values, lifted_phis[reverse_mapping[val]].ssa)
else
Expand All @@ -532,14 +537,31 @@ function perform_lifting!(compact::IncrementalCompact,

if stmt_val in keys(lifted_leaves)
stmt_val = lifted_leaves[stmt_val]
elseif isa(stmt_val, Union{NewSSAValue, SSAValue, OldSSAValue}) && stmt_val in keys(reverse_mapping)
elseif isa(stmt_val, AnySSAValue) && stmt_val in keys(reverse_mapping)
stmt_val = RefValue{Any}(lifted_phis[reverse_mapping[stmt_val]].ssa)
end

return stmt_val
end

assertion_counter = 0
"""
getfield_elim_pass!(ir::IRCode) -> newir::IRCode
`getfield` elimination pass, a.k.a. Scalar Replacements of Aggregates optimization.
This pass is based on a local alias analysis that collects field information by def-use chain walking.
It looks for struct allocation sites ("definitions"), and `getfield` calls as well as
`:foreigncall`s that preserve the structs ("usages"). If "definitions" have enough information,
then this pass will replace corresponding usages with lifted values.
`mutable struct`s require additional cares and need to be handled separately from immutables.
For `mutable struct`s, `setfield!` calls account for "definitions" also, and the pass should
give up the lifting conservatively when there are any "intermediate usages" that may escape
the mutable struct (e.g. non-inlined generic function call that takes the mutable struct as
its argument).
In a case when all usages are fully eliminated, `struct` allocation may also be erased as
a result of dead code elimination.
"""
function getfield_elim_pass!(ir::IRCode)
compact = IncrementalCompact(ir)
insertions = Vector{Any}()
Expand All @@ -554,7 +576,6 @@ function getfield_elim_pass!(ir::IRCode)
result_t = compact_exprtype(compact, SSAValue(idx))
is_getfield = is_setfield = false
field_ordering = :unspecified
is_ccall = false
# Step 1: Check whether the statement we're looking at is a getfield/setfield!
if is_known_call(stmt, setfield!, compact)
is_setfield = true
Expand Down Expand Up @@ -610,8 +631,8 @@ function getfield_elim_pass!(ir::IRCode)
old_preserves = stmt.args[(6+nccallargs):end]
for (pidx, preserved_arg) in enumerate(old_preserves)
isa(preserved_arg, SSAValue) || continue
let intermediaries = IdSet()
callback = function(@nospecialize(pi), ssa::AnySSAValue)
let intermediaries = IdSet{Int}()
callback = function (@nospecialize(pi), @nospecialize(ssa))
push!(intermediaries, ssa.id)
return false
end
Expand Down Expand Up @@ -670,8 +691,8 @@ function getfield_elim_pass!(ir::IRCode)

if ismutabletype(struct_typ)
isa(def, SSAValue) || continue
let intermediaries = IdSet()
callback = function(@nospecialize(pi), ssa::AnySSAValue)
let intermediaries = IdSet{Int}()
callback = function (@nospecialize(pi), @nospecialize(ssa))
push!(intermediaries, ssa.id)
return false
end
Expand All @@ -691,6 +712,8 @@ function getfield_elim_pass!(ir::IRCode)
continue
end

# perform SROA on immutable structs here on

if isa(def, Union{OldSSAValue, SSAValue})
def, typeconstraint = simple_walk_constraint(compact, def, typeconstraint)
end
Expand All @@ -703,7 +726,7 @@ function getfield_elim_pass!(ir::IRCode)
field = try_compute_fieldidx(struct_typ, field)
field === nothing && continue

r = lift_leaves(compact, stmt, result_t, field, leaves)
r = lift_leaves(compact, result_t, field, leaves)
r === nothing && continue
lifted_leaves, any_undef = r

Expand Down Expand Up @@ -736,14 +759,13 @@ function getfield_elim_pass!(ir::IRCode)
@assert val !== nothing
end

global assertion_counter
assertion_counter::Int += 1
# global assertion_counter
# assertion_counter::Int += 1
#insert_node_here!(compact, Expr(:assert_egal, Symbol(string("assert_egal_", assertion_counter)), SSAValue(idx), val), nothing, 0, true)
#continue
compact[idx] = val === nothing ? nothing : val.x
end


non_dce_finish!(compact)
# Copy the use count, `simple_dce!` may modify it and for our predicate
# below we need it consistent with the state of the IR here (after tracking
Expand Down Expand Up @@ -874,11 +896,12 @@ function getfield_elim_pass!(ir::IRCode)
end
ir
end
# assertion_counter = 0

function adce_erase!(phi_uses::Vector{Int}, extra_worklist::Vector{Int}, compact::IncrementalCompact, idx::Int)
# return whether this made a change
if isa(compact.result[idx][:inst], PhiNode)
return maybe_erase_unused!(extra_worklist, compact, idx, val -> phi_uses[val.id] -= 1)
return maybe_erase_unused!(extra_worklist, compact, idx, val::SSAValue -> phi_uses[val.id] -= 1)
else
return maybe_erase_unused!(extra_worklist, compact, idx)
end
Expand All @@ -893,7 +916,7 @@ function count_uses(@nospecialize(stmt), uses::Vector{Int})
end
end

function mark_phi_cycles(compact::IncrementalCompact, safe_phis::BitSet, phi::Int)
function mark_phi_cycles!(compact::IncrementalCompact, safe_phis::BitSet, phi::Int)
worklist = Int[]
push!(worklist, phi)
while !isempty(worklist)
Expand All @@ -909,6 +932,11 @@ function mark_phi_cycles(compact::IncrementalCompact, safe_phis::BitSet, phi::In
end
end

"""
adce_pass!(ir::IRCode) -> newir::IRCode
Aggressive Dead Code Elimination pass to eliminate code.
"""
function adce_pass!(ir::IRCode)
phi_uses = fill(0, length(ir.stmts) + length(ir.new_nodes))
all_phis = Int[]
Expand Down Expand Up @@ -940,7 +968,7 @@ function adce_pass!(ir::IRCode)
for phi in all_phis
# Save any phi cycles that have non-phi uses
if compact.used_ssas[phi] - phi_uses[phi] != 0
mark_phi_cycles(compact, safe_phis, phi)
mark_phi_cycles!(compact, safe_phis, phi)
end
end
for phi in all_phis
Expand Down
4 changes: 4 additions & 0 deletions base/compiler/ssair/verify.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,10 @@ function verify_ir(ir::IRCode, print::Bool=true)
@verify_error "SSAValue as assignment LHS"
error("")
end
if stmt.args[2] isa GlobalRef
# undefined GlobalRef as assignment RHS is OK
continue
end
elseif stmt.head === :gc_preserve_end
# We allow gc_preserve_end tokens to span across try/catch
# blocks, which isn't allowed for regular SSA values, so
Expand Down
27 changes: 13 additions & 14 deletions base/compiler/typelimits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -377,26 +377,25 @@ function tmerge(@nospecialize(typea), @nospecialize(typeb))
return Bool
end
# type-lattice for Const and PartialStruct wrappers
if (isa(typea, PartialStruct) || isa(typea, Const)) &&
(isa(typeb, PartialStruct) || isa(typeb, Const)) &&
widenconst(typea) === widenconst(typeb)
if ((isa(typea, PartialStruct) || isa(typea, Const)) &&
(isa(typeb, PartialStruct) || isa(typeb, Const)) &&
widenconst(typea) === widenconst(typeb))

typea_nfields = nfields_tfunc(typea)
typeb_nfields = nfields_tfunc(typeb)
if !isa(typea_nfields, Const) || !isa(typeb_nfields, Const) || typea_nfields.val !== typeb_nfields.val
typea_nfields = nfields_tfunc(typea)
typeb_nfields = nfields_tfunc(typeb)
if !isa(typea_nfields, Const) || !isa(typeb_nfields, Const) || typea_nfields.val !== typeb_nfields.val
return widenconst(typea)
end
end

type_nfields = typea_nfields.val::Int
fields = Vector{Any}(undef, type_nfields)
anyconst = false
for i = 1:type_nfields
type_nfields = typea_nfields.val::Int
fields = Vector{Any}(undef, type_nfields)
anyconst = false
for i = 1:type_nfields
fields[i] = tmerge(getfield_tfunc(typea, Const(i)),
getfield_tfunc(typeb, Const(i)))
anyconst |= has_nontrivial_const_info(fields[i])
end
return anyconst ? PartialStruct(widenconst(typea), fields) :
widenconst(typea)
end
return anyconst ? PartialStruct(widenconst(typea), fields) : widenconst(typea)
end
if isa(typea, PartialOpaque) && isa(typeb, PartialOpaque) && widenconst(typea) == widenconst(typeb)
if !(typea.source === typeb.source &&
Expand Down
14 changes: 8 additions & 6 deletions base/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -592,8 +592,8 @@ IteratorSize(::Type{<:Rest{I}}) where {I} = rest_iteratorsize(IteratorSize(I))

# Count -- infinite counting

struct Count{S<:Number}
start::S
struct Count{T,S}
start::T
step::S
end

Expand All @@ -613,11 +613,13 @@ julia> for v in Iterators.countfrom(5, 2)
9
```
"""
countfrom(start::Number, step::Number) = Count(promote(start, step)...)
countfrom(start::Number) = Count(start, oneunit(start))
countfrom() = Count(1, 1)
countfrom(start::T, step::S) where {T,S} = Count{typeof(start+step),S}(start, step)
countfrom(start::Number, step::Number) = Count(promote(start, step)...)
countfrom(start) = Count(start, oneunit(start))
countfrom() = Count(1, 1)

eltype(::Type{Count{S}}) where {S} = S

eltype(::Type{Count{T,S}}) where {T,S} = T

iterate(it::Count, state=it.start) = (state, state + it.step)

Expand Down
Loading

0 comments on commit d2914a9

Please sign in to comment.