From 91151ab871c7e7d6689d1cfa793c12062d37d6b6 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Fri, 29 Mar 2019 14:53:52 -0400 Subject: [PATCH] fix overflow in CartesianIndices iteration (#31011) This allows LLVM to vectorize the 1D CartesianIndices case, as well as fixing an overflow bug for: ```julia CartesianIndices(((typemax(Int64)-2):typemax(Int64),)) ``` Co-authored-by: Yingbo Ma Co-Authored-By: vchuravy --- base/array.jl | 30 ++++++++++++------- base/multidimensional.jl | 64 +++++++++++++++++++++++++++++----------- test/cartesian.jl | 49 ++++++++++++++++++++++++++++++ 3 files changed, 115 insertions(+), 28 deletions(-) diff --git a/base/array.jl b/base/array.jl index 119acfd71ae70..a80f481388158 100644 --- a/base/array.jl +++ b/base/array.jl @@ -1604,10 +1604,11 @@ CartesianIndex(2, 1) function findnext(A, start) l = last(keys(A)) i = start - while i <= l - if A[i] - return i - end + i > l && return nothing + while true + A[i] && return i + i == l && break + # nextind(A, l) can throw/overflow i = nextind(A, i) end return nothing @@ -1685,10 +1686,11 @@ CartesianIndex(1, 1) function findnext(testf::Function, A, start) l = last(keys(A)) i = start - while i <= l - if testf(A[i]) - return i - end + i > l && return nothing + while true + testf(A[i]) && return i + i == l && break + # nextind(A, l) can throw/overflow i = nextind(A, i) end return nothing @@ -1781,8 +1783,12 @@ CartesianIndex(2, 1) """ function findprev(A, start) i = start - while i >= first(keys(A)) + f = first(keys(A)) + i < f && return nothing + while true A[i] && return i + i == f && break + # prevind(A, f) can throw/underflow i = prevind(A, i) end return nothing @@ -1868,8 +1874,12 @@ CartesianIndex(2, 1) """ function findprev(testf::Function, A, start) i = start - while i >= first(keys(A)) + f = first(keys(A)) + i < f && return nothing + while true testf(A[i]) && return i + i == f && break + # prevind(A, f) can throw/underflow i = prevind(A, i) end return nothing diff --git 
a/base/multidimensional.jl b/base/multidimensional.jl index f03e641af10fc..4b9a3dd747bf1 100644 --- a/base/multidimensional.jl +++ b/base/multidimensional.jl @@ -95,6 +95,9 @@ module IteratorsMD # access to index tuple Tuple(index::CartesianIndex) = index.I + # equality + Base.:(==)(a::CartesianIndex{N}, b::CartesianIndex{N}) where N = a.I == b.I + # zeros and ones zero(::CartesianIndex{N}) where {N} = zero(CartesianIndex{N}) zero(::Type{CartesianIndex{N}}) where {N} = CartesianIndex(ntuple(x -> 0, Val(N))) @@ -142,11 +145,15 @@ module IteratorsMD # nextind and prevind with CartesianIndex function Base.nextind(a::AbstractArray{<:Any,N}, i::CartesianIndex{N}) where {N} iter = CartesianIndices(axes(a)) - return CartesianIndex(inc(i.I, first(iter).I, last(iter).I)) + # might overflow + I = inc(i.I, first(iter).I, last(iter).I) + return I end function Base.prevind(a::AbstractArray{<:Any,N}, i::CartesianIndex{N}) where {N} iter = CartesianIndices(axes(a)) - return CartesianIndex(dec(i.I, last(iter).I, first(iter).I)) + # might underflow + I = dec(i.I, last(iter).I, first(iter).I) + return I end # Iteration over the elements of CartesianIndex cannot be supported until its length can be inferred, @@ -334,20 +341,30 @@ module IteratorsMD iterfirst, iterfirst end @inline function iterate(iter::CartesianIndices, state) - nextstate = CartesianIndex(inc(state.I, first(iter).I, last(iter).I)) - nextstate.I[end] > last(iter.indices[end]) && return nothing - nextstate, nextstate + valid, I = __inc(state.I, first(iter).I, last(iter).I) + valid || return nothing + return CartesianIndex(I...), CartesianIndex(I...) end # increment & carry - @inline inc(::Tuple{}, ::Tuple{}, ::Tuple{}) = () - @inline inc(state::Tuple{Int}, start::Tuple{Int}, stop::Tuple{Int}) = (state[1]+1,) @inline function inc(state, start, stop) + _, I = __inc(state, start, stop) + return CartesianIndex(I...) 
+ end + + # increment post check to avoid integer overflow + @inline __inc(::Tuple{}, ::Tuple{}, ::Tuple{}) = false, () + @inline function __inc(state::Tuple{Int}, start::Tuple{Int}, stop::Tuple{Int}) + valid = state[1] < stop[1] + return valid, (state[1]+1,) + end + + @inline function __inc(state, start, stop) if state[1] < stop[1] - return (state[1]+1,tail(state)...) + return true, (state[1]+1, tail(state)...) end - newtail = inc(tail(state), tail(start), tail(stop)) - (start[1], newtail...) + valid, I = __inc(tail(state), tail(start), tail(stop)) + return valid, (start[1], I...) end # 0-d cartesian ranges are special-cased to iterate once and only once @@ -414,21 +431,32 @@ module IteratorsMD iterfirst, iterfirst end @inline function iterate(r::Reverse{<:CartesianIndices}, state) - nextstate = CartesianIndex(dec(state.I, last(r.itr).I, first(r.itr).I)) - nextstate.I[end] < first(r.itr.indices[end]) && return nothing - nextstate, nextstate + valid, I = __dec(state.I, last(r.itr).I, first(r.itr).I) + valid || return nothing + return CartesianIndex(I...), CartesianIndex(I...) end # decrement & carry - @inline dec(::Tuple{}, ::Tuple{}, ::Tuple{}) = () - @inline dec(state::Tuple{Int}, start::Tuple{Int}, stop::Tuple{Int}) = (state[1]-1,) @inline function dec(state, start, stop) + _, I = __dec(state, start, stop) + return CartesianIndex(I...) + end + + # decrement post check to avoid integer overflow + @inline __dec(::Tuple{}, ::Tuple{}, ::Tuple{}) = false, () + @inline function __dec(state::Tuple{Int}, start::Tuple{Int}, stop::Tuple{Int}) + valid = state[1] > stop[1] + return valid, (state[1]-1,) + end + + @inline function __dec(state, start, stop) if state[1] > stop[1] - return (state[1]-1,tail(state)...) + return true, (state[1]-1, tail(state)...) end - newtail = dec(tail(state), tail(start), tail(stop)) - (start[1], newtail...) + valid, I = __dec(tail(state), tail(start), tail(stop)) + return valid, (start[1], I...) 
end + # 0-d cartesian ranges are special-cased to iterate once and only once iterate(iter::Reverse{<:CartesianIndices{0}}, state=false) = state ? nothing : (CartesianIndex(), true) diff --git a/test/cartesian.jl b/test/cartesian.jl index 40badf6bb24bb..7de79bc6a407b 100644 --- a/test/cartesian.jl +++ b/test/cartesian.jl @@ -15,3 +15,52 @@ ex = Base.Cartesian.exprresolve(:(if 5 > 4; :x; else :y; end)) # can't convert higher-dimensional indices to Int @test_throws MethodError convert(Int, CartesianIndex(42, 1)) end + +@testset "CartesianIndices overflow" begin + I = CartesianIndices((1:typemax(Int),)) + i = last(I) + @test iterate(I, i) === nothing + + I = CartesianIndices((1:(typemax(Int)-1),)) + i = CartesianIndex(typemax(Int)) + @test iterate(I, i) === nothing + + I = CartesianIndices((1:typemax(Int), 1:typemax(Int))) + i = last(I) + @test iterate(I, i) === nothing + + i = CartesianIndex(typemax(Int), 1) + @test iterate(I, i) === (CartesianIndex(1, 2), CartesianIndex(1,2)) + + # range anchored at typemin(Int), iterated forward (NOTE(review): Iterators.reverse is not exercised here) + I = CartesianIndices((typemin(Int):(typemin(Int)+3),)) + i = last(I) + @test iterate(I, i) === nothing +end + +@testset "CartesianIndices iteration" begin + I = CartesianIndices((2:4, 0:1, 1:1, 3:5)) + indices = Vector{eltype(I)}() + for i in I + push!(indices, i) + end + @test length(I) == length(indices) + @test vec(I) == indices + + empty!(indices) + I = Iterators.reverse(I) + for i in I + push!(indices, i) + end + @test length(I) == length(indices) + @test vec(collect(I)) == indices + + # test invalid state + I = CartesianIndices((2:4, 3:5)) + @test iterate(I, CartesianIndex(typemax(Int), 3))[1] == CartesianIndex(2,4) + @test iterate(I, CartesianIndex(typemax(Int), 4))[1] == CartesianIndex(2,5) + @test iterate(I, CartesianIndex(typemax(Int), 5)) === nothing + + @test iterate(I, CartesianIndex(3, typemax(Int)))[1] == CartesianIndex(4,typemax(Int)) + @test iterate(I, CartesianIndex(4, typemax(Int))) === nothing +end