Skip to content

Commit

Permalink
Fix @simd for non 1 step CartesianPartition (#42736)
Browse files Browse the repository at this point in the history
  • Loading branch information
N5N3 authored Feb 25, 2022
1 parent 9af12d3 commit 2e2c16a
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 35 deletions.
6 changes: 3 additions & 3 deletions base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -973,14 +973,14 @@ end
destc = dest.chunks
cind = 1
bc′ = preprocess(dest, bc)
for P in Iterators.partition(eachindex(bc′), bitcache_size)
@inbounds for P in Iterators.partition(eachindex(bc′), bitcache_size)
ind = 1
@simd for I in P
@inbounds tmp[ind] = bc′[I]
tmp[ind] = bc′[I]
ind += 1
end
@simd for i in ind:bitcache_size
@inbounds tmp[i] = false
tmp[i] = false
end
dumpbitcache(destc, cind, tmp)
cind += bitcache_chunks
Expand Down
64 changes: 37 additions & 27 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -477,9 +477,8 @@ module IteratorsMD
simd_inner_length(iter::CartesianIndices, I::CartesianIndex) = Base.length(iter.indices[1])

simd_index(iter::CartesianIndices{0}, ::CartesianIndex, I1::Int) = first(iter)
@propagate_inbounds function simd_index(iter::CartesianIndices, Ilast::CartesianIndex, I1::Int)
CartesianIndex(getindex(iter.indices[1], I1+first(Base.axes1(iter.indices[1]))), Ilast.I...)
end
@propagate_inbounds simd_index(iter::CartesianIndices, Ilast::CartesianIndex, I1::Int) =
CartesianIndex(iter.indices[1][I1+firstindex(iter.indices[1])], Ilast)

# Split out the first N elements of a tuple
@inline function split(t, V::Val)
Expand Down Expand Up @@ -585,7 +584,7 @@ module IteratorsMD
CartesianIndices(intersect.(a.indices, b.indices))

# Views of reshaped CartesianIndices are used for partitions — ensure these are fast
const CartesianPartition{T<:CartesianIndex, P<:CartesianIndices, R<:ReshapedArray{T,1,P}} = SubArray{T,1,R,Tuple{UnitRange{Int}},false}
const CartesianPartition{T<:CartesianIndex, P<:CartesianIndices, R<:ReshapedArray{T,1,P}} = SubArray{T,1,R,<:Tuple{AbstractUnitRange{Int}},false}
eltype(::Type{PartitionIterator{T}}) where {T<:ReshapedArrayLF} = SubArray{eltype(T), 1, T, Tuple{UnitRange{Int}}, true}
eltype(::Type{PartitionIterator{T}}) where {T<:ReshapedArray} = SubArray{eltype(T), 1, T, Tuple{UnitRange{Int}}, false}
Iterators.IteratorEltype(::Type{<:PartitionIterator{T}}) where {T<:ReshapedArray} = Iterators.IteratorEltype(T)
Expand All @@ -594,7 +593,6 @@ module IteratorsMD
eltype(::Type{PartitionIterator{T}}) where {T<:Union{UnitRange, StepRange, StepRangeLen, LinRange}} = T
Iterators.IteratorEltype(::Type{<:PartitionIterator{T}}) where {T<:Union{OneTo, UnitRange, StepRange, StepRangeLen, LinRange}} = Iterators.IteratorEltype(T)


@inline function iterate(iter::CartesianPartition)
isempty(iter) && return nothing
f = first(iter)
Expand All @@ -610,33 +608,45 @@ module IteratorsMD
# In general, the Cartesian Partition might start and stop in the middle of the outer
# dimensions — thus the outer range of a CartesianPartition is itself a
# CartesianPartition.
t = tail(iter.parent.parent.indices)
ci = CartesianIndices(t)
li = LinearIndices(t)
return @inbounds view(ci, li[tail(iter[1].I)...]:li[tail(iter[end].I)...])
mi = iter.parent.mi
ci = iter.parent.parent
ax, ax1 = axes(ci), Base.axes1(ci)
subs = Base.ind2sub_rs(ax, mi, first(iter.indices[1]))
vl, fl = Base._sub2ind(tail(ax), tail(subs)...), subs[1]
vr, fr = divrem(last(iter.indices[1]) - 1, mi[end]) .+ (1, first(ax1))
oci = CartesianIndices(tail(ci.indices))
# A fake CartesianPartition to reuse the outer iterate fallback
outer = @inbounds view(ReshapedArray(oci, (length(oci),), mi), vl:vr)
init = @inbounds dec(oci[tail(subs)...].I, oci.indices) # real init state
# Use Generator to make inner loop branchless
@inline function skip_len_I(i::Int, I::CartesianIndex)
l = i == 1 ? fl : first(ax1)
r = i == length(outer) ? fr : last(ax1)
l - first(ax1), r - l + 1, I
end
(skip_len_I(i, I) for (i, I) in Iterators.enumerate(Iterators.rest(outer, (init, 0))))
end
function simd_outer_range(iter::CartesianPartition{CartesianIndex{2}})
@inline function simd_outer_range(iter::CartesianPartition{CartesianIndex{2}})
# But for two-dimensional Partitions the above is just a simple one-dimensional range
# over the second dimension; we don't need to worry about non-rectangular staggers in
# higher dimensions.
return @inbounds CartesianIndices((iter[1][2]:iter[end][2],))
end
@inline function simd_inner_length(iter::CartesianPartition, I::CartesianIndex)
inner = iter.parent.parent.indices[1]
@inbounds fi = iter[1].I
@inbounds li = iter[end].I
inner_start = I.I == tail(fi) ? fi[1] : first(inner)
inner_end = I.I == tail(li) ? li[1] : last(inner)
return inner_end - inner_start + 1
end
@inline function simd_index(iter::CartesianPartition, Ilast::CartesianIndex, I1::Int)
# I1 is the 0-based distance from the first dimension's offest
offset = first(iter.parent.parent.indices[1]) # (this is 1 for 1-based arrays)
# In the first column we need to also add in the iter's starting point (branchlessly)
f = @inbounds iter[1]
startoffset = (Ilast.I == tail(f.I))*(f[1] - 1)
CartesianIndex((I1 + offset + startoffset, Ilast.I...))
mi = iter.parent.mi
ci = iter.parent.parent
ax, ax1 = axes(ci), Base.axes1(ci)
fl, vl = Base.ind2sub_rs(ax, mi, first(iter.indices[1]))
fr, vr = Base.ind2sub_rs(ax, mi, last(iter.indices[1]))
outer = @inbounds CartesianIndices((ci.indices[2][vl:vr],))
# Use Generator to make inner loop branchless
@inline function skip_len_I(I::CartesianIndex{1})
l = I == first(outer) ? fl : first(ax1)
r = I == last(outer) ? fr : last(ax1)
l - first(ax1), r - l + 1, I
end
(skip_len_I(I) for I in outer)
end
@inline simd_inner_length(iter::CartesianPartition, (_, len, _)::Tuple{Int,Int,CartesianIndex}) = len
@propagate_inbounds simd_index(iter::CartesianPartition, (skip, _, I)::Tuple{Int,Int,CartesianIndex}, n::Int) =
simd_index(iter.parent.parent, I, n + skip)
end # IteratorsMD


Expand Down
13 changes: 8 additions & 5 deletions test/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -550,12 +550,15 @@ end
(1,1), (8,8), (11, 13),
(1,1,1), (8, 4, 2), (11, 13, 17)),
part in (1, 7, 8, 11, 63, 64, 65, 142, 143, 144)
P = partition(CartesianIndices(dims), part)
for I in P
@test length(I) == iterate_length(I) == simd_iterate_length(I) == simd_trip_count(I)
@test collect(I) == iterate_elements(I) == simd_iterate_elements(I) == index_elements(I)
for fun in (i -> 1:i, i -> 1:2:2i, i -> Base.IdentityUnitRange(-i:i))
iter = CartesianIndices(map(fun, dims))
P = partition(iter, part)
for I in P
@test length(I) == iterate_length(I) == simd_iterate_length(I) == simd_trip_count(I)
@test collect(I) == iterate_elements(I) == simd_iterate_elements(I) == index_elements(I)
end
@test all(Base.splat(==), zip(Iterators.flatten(map(collect, P)), iter))
end
@test all(Base.splat(==), zip(Iterators.flatten(map(collect, P)), CartesianIndices(dims)))
end
@testset "empty/invalid partitions" begin
@test_throws ArgumentError partition(1:10, 0)
Expand Down

0 comments on commit 2e2c16a

Please sign in to comment.