Skip to content

Commit

Permalink
feat: rehaul iteration and indexing
Browse files Browse the repository at this point in the history
- make A[i::Int, js::Array] return a Vector and support views
- make eachindex(A::MatrixElem) return a rowmajor iterator
- delete iterate(A::MatrixElem) (for now)
  • Loading branch information
thofma committed Jan 29, 2024
1 parent 0ef1f6a commit f79bdb4
Show file tree
Hide file tree
Showing 4 changed files with 199 additions and 21 deletions.
154 changes: 142 additions & 12 deletions src/Matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ end

_checkbounds(i::Int, j::Int) = 1 <= j <= i

function _checkbounds(A, i::Int, j::Int)
_checkbounds(i::Int, j::AbstractVector{Int}) = all(jj -> 1 <= jj <= i, j)

function _checkbounds(A, i::Union{Int, AbstractVector{Int}}, j::Union{Int, AbstractVector{Int}})
(_checkbounds(nrows(A), i) && _checkbounds(ncols(A), j)) ||
Base.throw_boundserror(A, (i, j))
end
Expand Down Expand Up @@ -386,18 +388,36 @@ function getindex(M::MatElem, rows::AbstractVector{Int}, cols::AbstractVector{In
return A
end

function getindex(M::MatElem, i::Int, cols::AbstractVector{Int})
_checkbounds(M, i, cols)
A = Vector{elem_type(base_ring(M))}(undef, length(cols))
for j in eachindex(cols)
A[j] = deepcopy(M[i, cols[j]])
end
return A
end

function getindex(M::MatElem, rows::AbstractVector{Int}, j::Int)
_checkbounds(M, rows, j)
A = Vector{elem_type(base_ring(M))}(undef, length(rows))
for i in eachindex(rows)
A[i] = deepcopy(M[rows[i], j])
end
return A
end

getindex(M::MatElem,
rows::Union{Int,Colon,AbstractVector{Int}},
cols::Union{Int,Colon,AbstractVector{Int}}) = M[_to_indices(M, rows, cols)...]

function _to_indices(x, rows, cols)
if rows isa Integer
rows = rows:rows
rows = rows
elseif rows isa Colon
rows = 1:nrows(x)
end
if cols isa Integer
cols = cols:cols
cols = cols
elseif cols isa Colon
cols = 1:ncols(x)
end
Expand All @@ -422,6 +442,52 @@ function Base.lastindex(M::MatrixElem{T}, i::Int) where T <: NCRingElement
end
end

##

struct RowMajorIndices{T}
C::T
end

Base.length(t::RowMajorIndices) = Base.length(t.C)

function Base.iterate(c::RowMajorIndices, s)
y = iterate(c.C, s)
if y === nothing
return nothing
end
return CartesianIndex(y[2][2], y[2][1]), y[1]
end

function Base.iterate(c::RowMajorIndices)
y = iterate(c.C)
if y === nothing
return nothing
end
return CartesianIndex(y[2][2], y[2][1]), y[1]
end

function Base.getindex(c::RowMajorIndices, i::Int)
r = c.C[i]
return CartesianIndex(r[2], r[1])
end

Base.IteratorSize(::Type{<:RowMajorIndices}) = Base.HasShape{2}()

Base.eltype(::Type{<:RowMajorIndices}) = CartesianIndex

Base.size(c::RowMajorIndices, dim) = size(c.C, dim == 1 ? 2 : 1)

Base.size(c::RowMajorIndices) = (size(c.C, 2), size(c.C, 1))

Base.CartesianIndices(a::MatrixElem) = RowMajorIndices(CartesianIndices((ncols(a), nrows(a))))

Base.LinearIndices(c::RowMajorIndices) = LinearIndices(c.C)'

# To make collect work
function Base.copyto!(dest::AbstractArray, src::RowMajorIndices)
return Base.copyto!(dest, CartesianIndices((size(src, 1), size(src, 2))))
end

###############################################################################
#
# Array interface
Expand All @@ -432,7 +498,7 @@ Base.ndims(::MatrixElem{T}) where T <: NCRingElement = 2

# Cartesian indexing

Base.eachindex(a::MatrixElem{T}) where T <: NCRingElement = CartesianIndices((nrows(a), ncols(a)))
Base.eachindex(a::MatrixElem{T}) where T <: NCRingElement = RowMajorIndices(CartesianIndices((ncols(a), nrows(a))))

Base.@propagate_inbounds Base.getindex(a::MatrixElem{T}, I::CartesianIndex) where T <: NCRingElement =
a[I[1], I[2]]
Expand Down Expand Up @@ -465,22 +531,86 @@ Base.@propagate_inbounds function setindex!(M::MatrixElem, x, i::Integer)
end
end

# iteration

Base.eltype(::Type{<:MatrixElem{T}}) where {T} = T

Base.length(c::MatrixElem) = length(c)

function Base.iterate(a::MatrixElem, ij=(1, 0))
i, j = ij
j += 1
if j > Base.size(a, 2)
iszero(size(a, 1)) && return nothing
j = 1
i += 1
end
i > size(a, 1) && return nothing
a[i, j], (i, j)
end

# iteration

function Base.iterate(a::MatrixElem{T}, ij=(0, 1)) where T <: NCRingElement
struct _RowMajorIterator{T}
a::T
end

Base.eltype(::Type{_RowMajorIterator{<:MatrixElem{T}}}) where {T} = T

Base.length(c::_RowMajorIterator{T}) where {T} = length(c.a)

function Base.iterate(a::_RowMajorIterator, ij=(1, 0))
i, j = ij
j += 1
if j > Base.size(a.a, 2)
iszero(size(a.a, 1)) && return nothing
j = 1
i += 1
end
i > size(a.a, 1) && return nothing
a.a[i, j], (i, j)
end

#Base.IteratorSize(::Type{<:MatrixElem}) = Base.HasShape{2}()
#Base.IteratorEltype(::Type{<:MatrixElem}) = Base.HasEltype() # default

struct _ColumnMajorIterator{T}
a::T
end

function Base.iterate(a::_ColumnMajorIterator, ij=(0, 1))
i, j = ij
i += 1
if i > nrows(a)
iszero(nrows(a)) && return nothing
if i > Base.size(a.a, 1)
iszero(size(a.a, 1)) && return nothing
i = 1
j += 1
end
j > ncols(a) && return nothing
a[i, j], (i, j)
j > size(a.a, 2) && return nothing
a.a[i, j], (i, j)
end

Base.IteratorSize(::Type{<:MatrixElem}) = Base.HasShape{2}()
Base.IteratorEltype(::Type{<:MatrixElem}) = Base.HasEltype() # default
Base.IteratorEltype(::Type{<:_ColumnMajorIterator}) = Base.HasEltype() # default

Base.eltype(::Type{_ColumnMajorIterator{<:MatrixElem{T}}}) where {T} = T

Base.length(c::_ColumnMajorIterator{T}) where {T} = length(c.a)

# the following is implemented to make collect(::MatrixElem) work
# function Base.copyto!(dest::AbstractArray, src::MatrixElem)
# # this breaks if dest is itself a MatrixElem
# # if this ever happens, we need to adjust the iteration
# @assert !(dest isa MatrixElem)
# destiter = eachindex(dest)
# y = iterate(destiter)
# for x in src#_ColumnMajorIterator(src)
# y === nothing &&
# throw(ArgumentError("destination has fewer elements than required"))
# dest[y[1]] = x
# y = iterate(destiter, y[2])
# end
# return dest
# end

###############################################################################
#
Expand Down Expand Up @@ -2519,7 +2649,7 @@ function trace_of_prod(M::MatElem, N::MatElem)
is_square(M) && is_square(N) || error("Not a square matrix in trace")
d = zero(base_ring(M))
for i = 1:nrows(M)
d += (M[i, :] * N[:, i])[1, 1]
d += (M[i:i, :] * N[:, i:i])[1, 1]
end
return d
end
Expand Down
5 changes: 5 additions & 0 deletions src/generic/GenericTypes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1096,6 +1096,11 @@ struct MatSpaceView{T <: NCRingElement, V, W} <: Mat{T}
base_ring::NCRing
end

struct MatSpaceVecView{T <: NCRingElement, V, W} <: AbstractVector{T}
entries::SubArray{T, 1, Matrix{T}, V, W}
base_ring::NCRing
end

###############################################################################
#
# MatAlgebra / MatAlgElem
Expand Down
19 changes: 18 additions & 1 deletion src/generic/Matrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,10 +100,18 @@ function deepcopy_internal(d::MatSpaceView{T}, dict::IdDict) where T <: NCRingEl
return MatSpaceView(deepcopy_internal(d.entries, dict), d.base_ring)
end

function Base.view(M::Mat{T}, rows::AbstractUnitRange{Int}, cols::AbstractUnitRange{Int}) where T <: NCRingElement
function Base.view(M::Mat{T}, rows::Union{Colon, AbstractVector{Int}}, cols::Union{Colon, AbstractVector{Int}}) where T <: NCRingElement
return MatSpaceView(view(M.entries, rows, cols), M.base_ring)
end

function Base.view(M::Mat{T}, rows::Int, cols::Union{Colon, AbstractVector{Int}}) where T <: NCRingElement
return MatSpaceVecView(view(M.entries, rows, cols), M.base_ring)
end

function Base.view(M::Mat{T}, rows::Union{Colon, AbstractVector{Int}}, cols::Int) where T <: NCRingElement
return MatSpaceVecView(view(M.entries, rows, cols), M.base_ring)
end

################################################################################
#
# Size, axes and is_square
Expand Down Expand Up @@ -228,3 +236,12 @@ function AbstractAlgebra.mul!(A::Mat{T}, B::Mat{T}, C::Mat{T}, f::Bool = false)
return A
end

Base.length(V::MatSpaceVecView) = length(V.entries)

Base.getindex(V::MatSpaceVecView, i::Int) = V.entries[i]

Base.setindex!(V::MatSpaceVecView{T}, z::T, i::Int) where {T} = (V.entries[i] = z)

Base.setindex!(V::MatSpaceVecView, z::RingElement, i::Int) = setindex!(V.entries, V.base_ring(z), i)

Base.size(V::MatSpaceVecView) = (length(V.entries), )
42 changes: 34 additions & 8 deletions test/generic/Matrix-test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -883,7 +883,7 @@ end
@test ndims(A) == 2

@test size(A) == (2, 3)
@test eachindex(A) == CartesianIndices((2, 3))
#@test eachindex(A) == CartesianIndices((2, 3))

# cartesian indexing
for i in eachindex(A)
Expand All @@ -899,9 +899,9 @@ end
@test_throws BoundsError A[CartesianIndex(rand(3:99), 1)]

# iteration
for (i, x) in enumerate(A)
@test A[Tuple(CartesianIndices(size(A))[i])...] == x
end
#for (i, x) in enumerate(A)
# @test A[Tuple(CartesianIndices(size(A))[i])...] == x
#end

AC = collect(A)
@test size(AC) == size(A)
Expand Down Expand Up @@ -1420,8 +1420,8 @@ end
Q = inv(P)

PA = P*A
@test PA == reduce(vcat, [A[Q[i], :] for i in 1:nrows(A)])
@test PA == reduce(vcat, A[Q[i], :] for i in 1:nrows(A))
@test PA == reduce(vcat, [A[Q[i]:Q[i], :] for i in 1:nrows(A)])
@test PA == reduce(vcat, A[Q[i]:Q[i], :] for i in 1:nrows(A))
@test PA == S(reduce(vcat, A.entries[Q[i], :] for i in 1:nrows(A)))
@test A == Q*(P*A)
end
Expand Down Expand Up @@ -4022,20 +4022,34 @@ end
@test fflu(N3) == fflu(M) # tests that deepcopy is correct
@test M2 == M

for i in [ 1, 1:2, : ], j in [ 1, 1:2, : ]
for i in [ 1:1, 1:2, : ], j in [ 1:1, 1:2, : ]
v = @view M[i,j]
@test v isa Generic.MatSpaceView
@test M[i,j] == v
end

M2 = deepcopy(M)
M3 = @view M2[2, 1:2]
@test length(M3) == 2
@test M3 == [2, 3]
M3[2] = 5
@test M2 == ZZ[1 2 3; 2 5 4; 3 4 5]

M2 = deepcopy(M)
M3 = @view M2[1:3, 3]
@test length(M3) == 3
@test M3 == [3, 4, 5]
M3[1] = 10
@test M2 == ZZ[1 2 10; 2 3 4; 3 4 5]

# Test views over noncommutative ring
R = MatrixAlgebra(ZZ, 2)

S = matrix_space(R, 4, 4)

M = rand(S, -10:10)

for i in [ 1, 1:2, : ], j in [ 1, 1:2, : ]
for i in [ 1:1, 1:2, : ], j in [ 1:1, 1:2, : ]
v = @view M[i,j]
@test v isa Generic.MatSpaceView
@test M[i,j] == v
Expand Down Expand Up @@ -4288,3 +4302,15 @@ end
@test base_ring(L) === QQ
@test L == N * change_base_ring(QQ, M)
end

@testset "Generic.indices" begin
A = ZZ[1 2 3; 4 5 6]
elts = elem_type(ZZ)[]
for i in eachindex(A)
push!(elts, A[i])
end
@test elts == [1, 2, 3, 4, 5, 6]
C = CartesianIndices(A)
L = LinearIndices(C)
@test all(L[c] == k for (k, c) in enumerate(C))
end

0 comments on commit f79bdb4

Please sign in to comment.