Skip to content

Commit

Permalink
Extend strides(::ReshapedArray) with non-contiguous strided parent
Browse files Browse the repository at this point in the history
  • Loading branch information
N5N3 committed Mar 9, 2022
1 parent 777910d commit d07b00e
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 11 deletions.
48 changes: 43 additions & 5 deletions base/reshapedarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -294,14 +294,52 @@ unsafe_convert(::Type{Ptr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{Union{RangeIndex
unsafe_convert(Ptr{T}, V.parent) + (first_index(V)-1)*sizeof(T)


_checkcontiguous(::Type{Bool}, A::AbstractArray) = size_to_strides(1, size(A)...) == strides(A)
_checkcontiguous(::Type{Bool}, A::AbstractArray) = false
_checkcontiguous(::Type{Bool}, A::Array) = true
_checkcontiguous(::Type{Bool}, A::ReshapedArray) = _checkcontiguous(Bool, parent(A))
_checkcontiguous(::Type{Bool}, A::FastContiguousSubArray) = _checkcontiguous(Bool, parent(A))

function strides(a::ReshapedArray)
# We can handle non-contiguous parent if it's a StridedVector
ndims(parent(a)) == 1 && return size_to_strides(only(strides(parent(a))), size(a)...)
_checkcontiguous(Bool, a) || throw(ArgumentError("Parent must be contiguous."))
size_to_strides(1, size(a)...)
if _checkcontiguous(Bool, a)
# simpify runtime convert when possibe
return size_to_strides(1, size(a)...)
end
apsz::Dims = size(a.parent)
apst::Dims = strides(a.parent)
msz, mst, n = merge_adjacent_dim(apsz, apst) # Try to perform "lazy" reshape
n == ndims(a.parent) && return size_to_strides(mst, size(a)...) # Parent is stridevector like
return _reshaped_strides(size(a), 1, msz, mst, n, apsz, apst)
end

function _reshaped_strides(::Dims{0}, reshaped::Int, msz::Int, ::Int, ::Int, ::Dims, ::Dims)
reshaped == msz || throw(ArgumentError("Input is not strided."))
()
end
function _reshaped_strides(sz::Dims, reshaped::Int, msz::Int, mst::Int, n::Int, apsz::Dims, apst::Dims)
st = reshaped * mst
reshaped = reshaped * sz[1]
if length(sz) > 1 && reshaped == msz && sz[2] != 1
msz, mst, n = merge_adjacent_dim(apsz, apst, n + 1)
reshaped = 1
end
sts = _reshaped_strides(tail(sz), reshaped, msz, mst, n, apsz, apst)
return (st, sts...)
end

merge_adjacent_dim(::Dims{0}, ::Dims{0}) = 1, 1, 0
merge_adjacent_dim(apsz::Dims{1}, apst::Dims{1}) = apsz[1], apst[1], 1
function merge_adjacent_dim(apsz::Dims{N}, apst::Dims{N}, n::Int = 1) where {N}
sz, st = apsz[n], apst[n]
while n < N
szₙ, stₙ = apsz[n+1], apst[n+1]
if sz == 1
sz, st = szₙ, stₙ
elseif stₙ == st * sz || szₙ == 1
sz *= szₙ
else
break
end
n += 1
end
return sz, st, n
end
44 changes: 38 additions & 6 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1567,22 +1567,54 @@ end
@test reshape(r, :) === reshape(r, (:,)) === r
end

struct FakeZeroDimArray <: AbstractArray{Int, 0} end
Base.strides(::FakeZeroDimArray) = ()
Base.size(::FakeZeroDimArray) = ()
@testset "strides for ReshapedArray" begin
# Type-based contiguous check is tested in test/compiler/inline.jl
function check_strides(A::AbstractArray)
# Make sure stride(A, i) is equivalent with strides(A)[i] (if 1 <= i <= ndims(A))
dims = ntuple(identity, ndims(A))
map(i -> stride(A, i), dims) == @inferred(strides(A)) || return false
# Test strides via value check.
for i in eachindex(IndexLinear(), A)
A[i] === Base.unsafe_load(pointer(A, i)) || return false
end
return true
end
# General contiguous check
a = view(rand(10,10), 1:10, 1:10)
@test strides(vec(a)) == (1,)
@test check_strides(vec(a))
b = view(parent(a), 1:9, 1:10)
@test_throws "Parent must be contiguous." strides(vec(b))
@test_throws "Input is not strided." strides(vec(b))
# StridedVector parent
for n in 1:3
a = view(collect(1:60n), 1:n:60n)
@test strides(reshape(a, 3, 4, 5)) == (n, 3n, 12n)
@test strides(reshape(a, 5, 6, 2)) == (n, 5n, 30n)
@test check_strides(reshape(a, 3, 4, 5))
@test check_strides(reshape(a, 5, 6, 2))
b = view(parent(a), 60n:-n:1)
@test strides(reshape(b, 3, 4, 5)) == (-n, -3n, -12n)
@test strides(reshape(b, 5, 6, 2)) == (-n, -5n, -30n)
@test check_strides(reshape(b, 3, 4, 5))
@test check_strides(reshape(b, 5, 6, 2))
end
# StridedVector like parent
a = randn(10, 10, 10)
b = view(a, 1:10, 1:1, 5:5)
@test check_strides(reshape(b, 2, 5))
# Other StridedArray parent
a = view(randn(10,10), 1:9, 1:10)
@test check_strides(reshape(a,3,3,2,5))
@test check_strides(reshape(a,3,3,5,2))
@test check_strides(reshape(a,9,5,2))
@test check_strides(reshape(a,3,3,10))
@test check_strides(reshape(a,1,3,1,3,1,5,1,2))
@test check_strides(reshape(a,3,3,5,1,1,2,1,1))
@test_throws "Input is not strided." strides(reshape(a,3,6,5))
@test_throws "Input is not strided." strides(reshape(a,3,2,3,5))
@test_throws "Input is not strided." strides(reshape(a,3,5,3,2))
@test_throws "Input is not strided." strides(reshape(a,5,3,3,2))
# Zero dimensional parent
a = reshape(FakeZeroDimArray(),1,1,1)
@test @inferred(strides(a)) == (1, 1, 1)
end

@testset "stride for 0 dims array #44087" begin
Expand Down

0 comments on commit d07b00e

Please sign in to comment.