Skip to content

Commit

Permalink
fix overflow in CartesianIndices iteration (#31011)
Browse files Browse the repository at this point in the history
This allows LLVM to vectorize the 1D CartesianIndices case,
as well as fixing an overflow bug for:

```julia
CartesianIndices(((typemax(Int64)-2):typemax(Int64),))
```

Co-authored-by: Yingbo Ma <mayingbo5@gmail.com>
Co-Authored-By: vchuravy <vchuravy@users.noreply.github.com>
  • Loading branch information
vchuravy and YingboMa authored Mar 29, 2019
1 parent a399780 commit 91151ab
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 28 deletions.
30 changes: 20 additions & 10 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1604,10 +1604,11 @@ CartesianIndex(2, 1)
function findnext(A, start)
l = last(keys(A))
i = start
while i <= l
if A[i]
return i
end
i > l && return nothing
while true
A[i] && return i
i == l && break
# nextind(A, l) can throw/overflow
i = nextind(A, i)
end
return nothing
Expand Down Expand Up @@ -1685,10 +1686,11 @@ CartesianIndex(1, 1)
function findnext(testf::Function, A, start)
l = last(keys(A))
i = start
while i <= l
if testf(A[i])
return i
end
i > l && return nothing
while true
testf(A[i]) && return i
i == l && break
# nextind(A, l) can throw/overflow
i = nextind(A, i)
end
return nothing
Expand Down Expand Up @@ -1781,8 +1783,12 @@ CartesianIndex(2, 1)
"""
function findprev(A, start)
i = start
while i >= first(keys(A))
f = first(keys(A))
i < f && return nothing
while true
A[i] && return i
i == f && break
# prevind(A, f) can throw/underflow
i = prevind(A, i)
end
return nothing
Expand Down Expand Up @@ -1868,8 +1874,12 @@ CartesianIndex(2, 1)
"""
function findprev(testf::Function, A, start)
i = start
while i >= first(keys(A))
f = first(keys(A))
i < f && return nothing
while true
testf(A[i]) && return i
i == f && break
# prevind(A, f) can throw/underflow
i = prevind(A, i)
end
return nothing
Expand Down
64 changes: 46 additions & 18 deletions base/multidimensional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ module IteratorsMD
# access to index tuple
Tuple(index::CartesianIndex) = index.I

# equality
Base.:(==)(a::CartesianIndex{N}, b::CartesianIndex{N}) where N = a.I == b.I

# zeros and ones
zero(::CartesianIndex{N}) where {N} = zero(CartesianIndex{N})
zero(::Type{CartesianIndex{N}}) where {N} = CartesianIndex(ntuple(x -> 0, Val(N)))
Expand Down Expand Up @@ -142,11 +145,15 @@ module IteratorsMD
# nextind and prevind with CartesianIndex
function Base.nextind(a::AbstractArray{<:Any,N}, i::CartesianIndex{N}) where {N}
iter = CartesianIndices(axes(a))
return CartesianIndex(inc(i.I, first(iter).I, last(iter).I))
# might overflow
I = inc(i.I, first(iter).I, last(iter).I)
return I
end
function Base.prevind(a::AbstractArray{<:Any,N}, i::CartesianIndex{N}) where {N}
iter = CartesianIndices(axes(a))
return CartesianIndex(dec(i.I, last(iter).I, first(iter).I))
# might underflow
I = dec(i.I, last(iter).I, first(iter).I)
return I
end

# Iteration over the elements of CartesianIndex cannot be supported until its length can be inferred,
Expand Down Expand Up @@ -334,20 +341,30 @@ module IteratorsMD
iterfirst, iterfirst
end
@inline function iterate(iter::CartesianIndices, state)
nextstate = CartesianIndex(inc(state.I, first(iter).I, last(iter).I))
nextstate.I[end] > last(iter.indices[end]) && return nothing
nextstate, nextstate
valid, I = __inc(state.I, first(iter).I, last(iter).I)
valid || return nothing
return CartesianIndex(I...), CartesianIndex(I...)
end

# increment & carry
@inline inc(::Tuple{}, ::Tuple{}, ::Tuple{}) = ()
@inline inc(state::Tuple{Int}, start::Tuple{Int}, stop::Tuple{Int}) = (state[1]+1,)
@inline function inc(state, start, stop)
_, I = __inc(state, start, stop)
return CartesianIndex(I...)
end

# increment post check to avoid integer overflow
@inline __inc(::Tuple{}, ::Tuple{}, ::Tuple{}) = false, ()
@inline function __inc(state::Tuple{Int}, start::Tuple{Int}, stop::Tuple{Int})
valid = state[1] < stop[1]
return valid, (state[1]+1,)
end

@inline function __inc(state, start, stop)
if state[1] < stop[1]
return (state[1]+1,tail(state)...)
return true, (state[1]+1, tail(state)...)
end
newtail = inc(tail(state), tail(start), tail(stop))
(start[1], newtail...)
valid, I = __inc(tail(state), tail(start), tail(stop))
return valid, (start[1], I...)
end

# 0-d cartesian ranges are special-cased to iterate once and only once
Expand Down Expand Up @@ -414,21 +431,32 @@ module IteratorsMD
iterfirst, iterfirst
end
@inline function iterate(r::Reverse{<:CartesianIndices}, state)
nextstate = CartesianIndex(dec(state.I, last(r.itr).I, first(r.itr).I))
nextstate.I[end] < first(r.itr.indices[end]) && return nothing
nextstate, nextstate
valid, I = __dec(state.I, last(r.itr).I, first(r.itr).I)
valid || return nothing
return CartesianIndex(I...), CartesianIndex(I...)
end

# decrement & carry
@inline dec(::Tuple{}, ::Tuple{}, ::Tuple{}) = ()
@inline dec(state::Tuple{Int}, start::Tuple{Int}, stop::Tuple{Int}) = (state[1]-1,)
@inline function dec(state, start, stop)
_, I = __dec(state, start, stop)
return CartesianIndex(I...)
end

# decrement post check to avoid integer overflow
@inline __dec(::Tuple{}, ::Tuple{}, ::Tuple{}) = false, ()
@inline function __dec(state::Tuple{Int}, start::Tuple{Int}, stop::Tuple{Int})
valid = state[1] > stop[1]
return valid, (state[1]-1,)
end

@inline function __dec(state, start, stop)
if state[1] > stop[1]
return (state[1]-1,tail(state)...)
return true, (state[1]-1, tail(state)...)
end
newtail = dec(tail(state), tail(start), tail(stop))
(start[1], newtail...)
valid, I = __dec(tail(state), tail(start), tail(stop))
return valid, (start[1], I...)
end

# 0-d cartesian ranges are special-cased to iterate once and only once
iterate(iter::Reverse{<:CartesianIndices{0}}, state=false) = state ? nothing : (CartesianIndex(), true)

Expand Down
49 changes: 49 additions & 0 deletions test/cartesian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,52 @@ ex = Base.Cartesian.exprresolve(:(if 5 > 4; :x; else :y; end))
# can't convert higher-dimensional indices to Int
@test_throws MethodError convert(Int, CartesianIndex(42, 1))
end

@testset "CartesianIndices overflow" begin
I = CartesianIndices((1:typemax(Int),))
i = last(I)
@test iterate(I, i) === nothing

I = CartesianIndices((1:(typemax(Int)-1),))
i = CartesianIndex(typemax(Int))
@test iterate(I, i) === nothing

I = CartesianIndices((1:typemax(Int), 1:typemax(Int)))
i = last(I)
@test iterate(I, i) === nothing

i = CartesianIndex(typemax(Int), 1)
@test iterate(I, i) === (CartesianIndex(1, 2), CartesianIndex(1,2))

# reverse cartesian indices
I = CartesianIndices((typemin(Int):(typemin(Int)+3),))
i = last(I)
@test iterate(I, i) === nothing
end

@testset "CartesianIndices iteration" begin
I = CartesianIndices((2:4, 0:1, 1:1, 3:5))
indices = Vector{eltype(I)}()
for i in I
push!(indices, i)
end
@test length(I) == length(indices)
@test vec(I) == indices

empty!(indices)
I = Iterators.reverse(I)
for i in I
push!(indices, i)
end
@test length(I) == length(indices)
@test vec(collect(I)) == indices

# test invalid state
I = CartesianIndices((2:4, 3:5))
@test iterate(I, CartesianIndex(typemax(Int), 3))[1] == CartesianIndex(2,4)
@test iterate(I, CartesianIndex(typemax(Int), 4))[1] == CartesianIndex(2,5)
@test iterate(I, CartesianIndex(typemax(Int), 5)) === nothing

@test iterate(I, CartesianIndex(3, typemax(Int)))[1] == CartesianIndex(4,typemax(Int))
@test iterate(I, CartesianIndex(4, typemax(Int))) === nothing
end

0 comments on commit 91151ab

Please sign in to comment.