Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

hvncat: Extended docstring, reorganized helper methods, more consistent error throwing #41195

Merged
merged 1 commit into from
Jun 14, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 42 additions & 19 deletions base/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1635,6 +1635,12 @@ cat_size(A::AbstractArray) = size(A)
cat_size(A, d) = 1
cat_size(A::AbstractArray, d) = size(A, d)

cat_length(::Any) = 1
cat_length(a::AbstractArray) = length(a)

cat_ndims(a) = 0
cat_ndims(a::AbstractArray) = ndims(a)

cat_indices(A, d) = OneTo(1)
cat_indices(A::AbstractArray, d) = axes(A, d)

Expand Down Expand Up @@ -2034,22 +2040,25 @@ function typed_hvcat(::Type{T}, rows::Tuple{Vararg{Int}}, as...) where T
T[rs...;]
end

# nd concatenation
## N-dimensional concatenation ##

"""
hvncat(dim::Int, row_first, values...)
hvncat(dims::Tuple{Vararg{Int}}, row_first, values...)
hvncat(shape::Tuple{Vararg{Tuple}}, row_first, values...)

Horizontal, vertical, and n-dimensional concatenation of many `values` in one call.

This function is called
for block matrix syntax. The first argument either specifies the shape of the concatenation,
similar to `hvcat`, as a tuple of tuples, or the dimensions that specify the key number of
elements along each axis, and is used to determine the output dimensions. The `dims` form
is more performant, and is used by default when the concatenation operation has the same
number of elements along each axis (e.g., [a b; c d;;; e f ; g h]). The `shape` form is used
when the number of elements along each axis is unbalanced (e.g., [a b ; c]). Unbalanced
syntax needs additional validation overhead.
This function is called for block matrix syntax. The first argument either specifies the
shape of the concatenation, similar to `hvcat`, as a tuple of tuples, or the dimensions that
specify the key number of elements along each axis, and is used to determine the output
dimensions. The `dims` form is more performant, and is used by default when the concatenation
operation has the same number of elements along each axis (e.g., [a b; c d;;; e f ; g h]).
The `shape` form is used when the number of elements along each axis is unbalanced
(e.g., [a b ; c]). Unbalanced syntax needs additional validation overhead. The `dim` form
is an optimization for concatenation along just one dimension. `row_first` indicates how
`values` are ordered. The meaning of the first and second elements of `shape` are also
swapped based on `row_first`.

# Examples
```jldoctest
Expand Down Expand Up @@ -2097,6 +2106,24 @@ julia> hvncat(((3, 3), (3, 3), (6,)), true, a, b, c, d, e, f)
[:, :, 2] =
4 5 6
```

# Examples for construction of the arguments:
[a b c ; d e f ;;;
g h i ; j k l ;;;
m n o ; p q r ;;;
s t u ; v w x]
=> dims = (2, 3, 4)

[a b ; c ;;; d ;;;;]
___ _ _
2 1 1 = elements in each row (2, 1, 1)
_______ _
3 1 = elements in each column (3, 1)
_____________
4 = elements in each 3d slice (4,)
_____________
4 = elements in each 4d slice (4,)
=> shape = ((2, 1, 1), (3, 1), (4,), (4,)) with `rowfirst` = true
"""
hvncat(::Tuple{}, ::Bool) = []
hvncat(::Tuple{}, ::Bool, xs...) = []
Expand Down Expand Up @@ -2188,9 +2215,6 @@ function _typed_hvncat(::Type{T}, ::Val{N}, as::AbstractArray...) where {T, N}
return A
end

cat_ndims(a) = 0
cat_ndims(a::AbstractArray) = ndims(a)

function _typed_hvncat(::Type{T}, ::Val{N}, as...) where {T, N}
# optimization for scalars and 1-length arrays that can be concatenated by copying them linearly
# into the destination
Expand Down Expand Up @@ -2257,12 +2281,12 @@ function _typed_hvncat(::Type{T}, dims::Tuple{Vararg{Int, N}}, row_first::Bool,
elseif currentdims[d] < outdims[d] # dimension in progress
break
else # exceeded dimension
ArgumentError("argument $i has too many elements along axis $d") |> throw
throw(ArgumentError("argument $i has too many elements along axis $d"))
end
end
end
elseif currentdims[d1] > outdims[d1] # exceeded dimension
ArgumentError("argument $i has too many elements along axis $d1") |> throw
throw(ArgumentError("argument $i has too many elements along axis $d1"))
end
end

Expand Down Expand Up @@ -2308,7 +2332,8 @@ function _typed_hvncat(::Type{T}, shape::Tuple{Vararg{Tuple, N}}, row_first::Boo
if d == 1 || i == 1 || wasstartblock
currentdims[d] += dsize
elseif dsize != cat_size(as[i - 1], ad)
ArgumentError("argument $i has a mismatched number of elements along axis $ad; expected $(cat_size(as[i - 1], ad)), got $dsize") |> throw
throw(ArgumentError("""argument $i has a mismatched number of elements along axis $ad; \
expected $(cat_size(as[i - 1], ad)), got $dsize"""))
Comment on lines +2335 to +2336
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Love seeing this of course! :)

end

wasstartblock = blockcounts[d] == 1 # remember for next dimension
Expand All @@ -2318,7 +2343,8 @@ function _typed_hvncat(::Type{T}, shape::Tuple{Vararg{Tuple, N}}, row_first::Boo
if outdims[d] == 0
outdims[d] = currentdims[d]
elseif outdims[d] != currentdims[d]
ArgumentError("argument $i has a mismatched number of elements along axis $ad; expected $(abs(outdims[d] - (currentdims[d] - dsize))), got $dsize") |> throw
throw(ArgumentError("""argument $i has a mismatched number of elements along axis $ad; \
expected $(abs(outdims[d] - (currentdims[d] - dsize))), got $dsize"""))
end
currentdims[d] = 0
blockcounts[d] = 0
Expand Down Expand Up @@ -2382,9 +2408,6 @@ end
Ai
end

cat_length(a::AbstractArray) = length(a)
cat_length(::Any) = 1

## Reductions and accumulates ##

function isequal(A::AbstractArray, B::AbstractArray)
Expand Down