Skip to content

Commit

Permalink
overload permutedims for OneElement matrices (#368)
Browse files Browse the repository at this point in the history
* overload for OneElement matrices

* add unit test

* more general implementations and unit tests

* make adjustments

* add unit test and style
  • Loading branch information
max-vassili3v authored Aug 16, 2024
1 parent f659d9f commit f0f7618
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/FillArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import Base: size, getindex, setindex!, IndexStyle, checkbounds, convert,
any, all, axes, isone, iszero, iterate, unique, allunique, permutedims, inv,
copy, vec, setindex!, count, ==, reshape, map, zero,
show, view, in, mapreduce, one, reverse, promote_op, promote_rule, repeat,
parent, similar, issorted, add_sum, accumulate, OneTo
parent, similar, issorted, add_sum, accumulate, OneTo, permutedims

import LinearAlgebra: rank, svdvals!, tril, triu, tril!, triu!, diag, transpose, adjoint, fill!,
dot, norm2, norm1, normInf, normMinusInf, normp, lmul!, rmul!, diagzero, AdjointAbsVec, TransposeAbsVec,
Expand Down
6 changes: 6 additions & 0 deletions src/oneelement.jl
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,12 @@ function Base.reshape(A::OneElement, shape::Tuple{Vararg{Int}})
OneElement(A.val, Tuple(newcartind), shape)
end

#permute
_permute(x, p) = ntuple(i -> x[p[i]], length(x))
permutedims(o::OneElementMatrix) = OneElement(o.val, reverse(o.ind), reverse(o.axes))
permutedims(o::OneElementVector) = reshape(o, (1, length(o)))
permutedims(o::OneElement, dims) = OneElement(o.val, _permute(o.ind, dims), _permute(o.axes, dims))

# show
_maybesize(t::Tuple{Base.OneTo{Int}, Vararg{Base.OneTo{Int}}}) = size.(t,1)
_maybesize(t) = t
Expand Down
11 changes: 11 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2667,6 +2667,17 @@ end
end
end

@testset "permutedims" begin
v = OneElement(1, (2, 3), (2, 5))
@test permutedims(v) === OneElement(1, (3, 2), (5, 2))
w = OneElement(1, 3, 5)
@test permutedims(w) === OneElement(1, (1, 3), (1, 5))
x = OneElement(1, (1, 2, 3), (4, 5, 6))
@test permutedims(x, (1, 2, 3)) === x
@test permutedims(x, (3, 2, 1)) === OneElement(1, (3, 2, 1), (6, 5, 4))
@test permutedims(x, [2, 3, 1]) === OneElement(1, (2, 3, 1), (5, 6, 4))
@test_throws BoundsError permutedims(x, (2, 1))
end
@testset "show" begin
B = OneElement(2, (1, 2), (3, 4))
@test repr(B) == "OneElement(2, (1, 2), (3, 4))"
Expand Down

0 comments on commit f0f7618

Please sign in to comment.