diff --git a/NEWS.md b/NEWS.md index f871424e47337..421093b6e0681 100644 --- a/NEWS.md +++ b/NEWS.md @@ -68,6 +68,9 @@ New library functions inspecting which function `f` was originally wrapped. ([#42717]) * New `pkgversion(m::Module)` function to get the version of the package that loaded a given module, similar to `pkgdir(m::Module)`. ([#45607]) +* New function `stack(x)` which generalises `reduce(hcat, x::Vector{<:Vector})` to any dimensionality, + and allows any iterators of iterators. Method `stack(f, x)` generalises `mapreduce(f, hcat, x)` and + is efficient. ([#43334]) Library changes --------------- diff --git a/base/abstractarray.jl b/base/abstractarray.jl index 1690aa4a9e56f..7663df3425908 100644 --- a/base/abstractarray.jl +++ b/base/abstractarray.jl @@ -2605,6 +2605,236 @@ end Ai end +""" + stack(iter; [dims]) + +Combine a collection of arrays (or other iterable objects) of equal size +into one larger array, by arranging them along one or more new dimensions. + +By default the axes of the elements are placed first, +giving `size(result) = (size(first(iter))..., size(iter)...)`. +This has the same order of elements as [`Iterators.flatten`](@ref)`(iter)`. + +With keyword `dims::Integer`, instead the `i`th element of `iter` becomes the slice +[`selectdim`](@ref)`(result, dims, i)`, so that `size(result, dims) == length(iter)`. +In this case `stack` reverses the action of [`eachslice`](@ref) with the same `dims`. + +The various [`cat`](@ref) functions also combine arrays. However, these all +extend the arrays' existing (possibly trivial) dimensions, rather than placing +the arrays along new dimensions. +They also accept arrays as separate arguments, rather than a single collection. + +!!! compat "Julia 1.9" + This function requires at least Julia 1.9. + +# Examples +```jldoctest +julia> vecs = (1:2, [30, 40], Float32[500, 600]); + +julia> mat = stack(vecs) +2×3 Matrix{Float32}: + 1.0 30.0 500.0 + 2.0 40.0 600.0 + +julia> mat == hcat(vecs...) == reduce(hcat, collect(vecs)) +true + +julia> vec(mat) == vcat(vecs...) == reduce(vcat, collect(vecs)) +true + +julia> stack(zip(1:4, 10:99)) # accepts any iterators of iterators +2×4 Matrix{Int64}: + 1 2 3 4 + 10 11 12 13 + +julia> vec(ans) == collect(Iterators.flatten(zip(1:4, 10:99))) +true + +julia> stack(vecs; dims=1) # unlike any cat function, 1st axis of vecs[1] is 2nd axis of result +3×2 Matrix{Float32}: + 1.0 2.0 + 30.0 40.0 + 500.0 600.0 + +julia> x = rand(3,4); + +julia> x == stack(eachcol(x)) == stack(eachrow(x), dims=1) # inverse of eachslice +true +``` + +Higher-dimensional examples: + +```jldoctest +julia> A = rand(5, 7, 11); + +julia> E = eachslice(A, dims=2); # a vector of matrices + +julia> (element = size(first(E)), container = size(E)) +(element = (5, 11), container = (7,)) + +julia> stack(E) |> size +(5, 11, 7) + +julia> stack(E) == stack(E; dims=3) == cat(E...; dims=3) +true + +julia> A == stack(E; dims=2) +true + +julia> M = (fill(10i+j, 2, 3) for i in 1:5, j in 1:7); + +julia> (element = size(first(M)), container = size(M)) +(element = (2, 3), container = (5, 7)) + +julia> stack(M) |> size # keeps all dimensions +(2, 3, 5, 7) + +julia> stack(M; dims=1) |> size # vec(container) along dims=1 +(35, 2, 3) + +julia> hvcat(5, M...) |> size # hvcat puts matrices next to each other +(14, 15) +``` +""" +stack(iter; dims=:) = _stack(dims, iter) + +""" + stack(f, args...; [dims]) + +Apply a function to each element of a collection, and `stack` the result. +Or to several collections, [`zip`](@ref)ped together. + +The function should return arrays (or tuples, or other iterators) all of the same size. +These become slices of the result, each separated along `dims` (if given) or by default +along the last dimensions. + +See also [`mapslices`](@ref), [`eachcol`](@ref). + +# Examples +```jldoctest +julia> stack(c -> (c, c-32), "julia") +2×5 Matrix{Char}: + 'j' 'u' 'l' 'i' 'a' + 'J' 'U' 'L' 'I' 'A' + +julia> stack(eachrow([1 2 3; 4 5 6]), (10, 100); dims=1) do row, n + vcat(row, row .* n, row ./ n) + end +2×9 Matrix{Float64}: + 1.0 2.0 3.0 10.0 20.0 30.0 0.1 0.2 0.3 + 4.0 5.0 6.0 400.0 500.0 600.0 0.04 0.05 0.06 +``` +""" +stack(f, iter; dims=:) = _stack(dims, f(x) for x in iter) +stack(f, xs, yzs...; dims=:) = _stack(dims, f(xy...) for xy in zip(xs, yzs...)) + +_stack(dims::Union{Integer, Colon}, iter) = _stack(dims, IteratorSize(iter), iter) + +_stack(dims, ::IteratorSize, iter) = _stack(dims, collect(iter)) + +function _stack(dims, ::Union{HasShape, HasLength}, iter) + S = @default_eltype iter + T = S != Union{} ? eltype(S) : Any # Union{} occurs for e.g. stack(1,2), postpone the error + if isconcretetype(T) + _typed_stack(dims, T, S, iter) + else # Need to look inside, but shouldn't run an expensive iterator twice: + array = iter isa Union{Tuple, AbstractArray} ? iter : collect(iter) + isempty(array) && return _empty_stack(dims, T, S, iter) + T2 = mapreduce(eltype, promote_type, array) + _typed_stack(dims, T2, eltype(array), array) + end +end + +function _typed_stack(::Colon, ::Type{T}, ::Type{S}, A, Aax=_iterator_axes(A)) where {T, S} + xit = iterate(A) + nothing === xit && return _empty_stack(:, T, S, A) + x1, _ = xit + ax1 = _iterator_axes(x1) + B = similar(_ensure_array(x1), T, ax1..., Aax...) + off = firstindex(B) + len = length(x1) + while xit !== nothing + x, state = xit + _stack_size_check(x, ax1) + copyto!(B, off, x) + off += len + xit = iterate(A, state) + end + B +end + +_iterator_axes(x) = _iterator_axes(x, IteratorSize(x)) +_iterator_axes(x, ::HasLength) = (OneTo(length(x)),) +_iterator_axes(x, ::IteratorSize) = axes(x) + +# For some dims values, stack(A; dims) == stack(vec(A)), and the : path will be faster +_typed_stack(dims::Integer, ::Type{T}, ::Type{S}, A) where {T,S} = + _typed_stack(dims, T, S, IteratorSize(S), A) +_typed_stack(dims::Integer, ::Type{T}, ::Type{S}, ::HasLength, A) where {T,S} = + _typed_stack(dims, T, S, HasShape{1}(), A) +function _typed_stack(dims::Integer, ::Type{T}, ::Type{S}, ::HasShape{N}, A) where {T,S,N} + if dims == N+1 + _typed_stack(:, T, S, A, (_vec_axis(A),)) + else + _dim_stack(dims, T, S, A) + end +end +_typed_stack(dims::Integer, ::Type{T}, ::Type{S}, ::IteratorSize, A) where {T,S} = + _dim_stack(dims, T, S, A) + +_vec_axis(A, ax=_iterator_axes(A)) = length(ax) == 1 ? only(ax) : OneTo(prod(length, ax; init=1)) + +@constprop :aggressive function _dim_stack(dims::Integer, ::Type{T}, ::Type{S}, A) where {T,S} + xit = Iterators.peel(A) + nothing === xit && return _empty_stack(dims, T, S, A) + x1, xrest = xit + ax1 = _iterator_axes(x1) + N1 = length(ax1)+1 + dims in 1:N1 || throw(ArgumentError(LazyString("cannot stack slices ndims(x) = ", N1-1, " along dims = ", dims))) + + newaxis = _vec_axis(A) + outax = ntuple(d -> d==dims ? newaxis : ax1[d - (d>dims)], N1) + B = similar(_ensure_array(x1), T, outax...) + + if dims == 1 + _dim_stack!(Val(1), B, x1, xrest) + elseif dims == 2 + _dim_stack!(Val(2), B, x1, xrest) + else + _dim_stack!(Val(dims), B, x1, xrest) + end + B +end + +function _dim_stack!(::Val{dims}, B::AbstractArray, x1, xrest) where {dims} + before = ntuple(d -> Colon(), dims - 1) + after = ntuple(d -> Colon(), ndims(B) - dims) + + i = firstindex(B, dims) + copyto!(view(B, before..., i, after...), x1) + + for x in xrest + _stack_size_check(x, _iterator_axes(x1)) + i += 1 + @inbounds copyto!(view(B, before..., i, after...), x) + end +end + +@inline function _stack_size_check(x, ax1::Tuple) + if _iterator_axes(x) != ax1 + uax1 = map(UnitRange, ax1) + uaxN = map(UnitRange, axes(x)) + throw(DimensionMismatch( + LazyString("stack expects uniform slices, got axes(x) == ", uaxN, " while first had ", uax1))) + end +end + +_ensure_array(x::AbstractArray) = x +_ensure_array(x) = 1:0 # passed to similar, makes stack's output an Array + +_empty_stack(_...) = throw(ArgumentError("`stack` on an empty collection is not allowed")) + + ## Reductions and accumulates ## function isequal(A::AbstractArray, B::AbstractArray) diff --git a/base/exports.jl b/base/exports.jl index 9c5601a8740dd..f64c3b2913260 100644 --- a/base/exports.jl +++ b/base/exports.jl @@ -445,6 +445,7 @@ export sortperm!, sortslices, dropdims, + stack, step, stride, strides, diff --git a/base/iterators.jl b/base/iterators.jl index 0184ab51323b4..41043f1cc9f0a 100644 --- a/base/iterators.jl +++ b/base/iterators.jl @@ -1199,7 +1199,7 @@ See also [`Iterators.flatten`](@ref), [`Iterators.map`](@ref). # Examples ```jldoctest -julia> Iterators.flatmap(n->-n:2:n, 1:3) |> collect +julia> Iterators.flatmap(n -> -n:2:n, 1:3) |> collect 9-element Vector{Int64}: -1 1 @@ -1210,6 +1210,20 @@ julia> Iterators.flatmap(n->-n:2:n, 1:3) |> collect -1 1 3 + +julia> stack(n -> -n:2:n, 1:3) +ERROR: DimensionMismatch: stack expects uniform slices, got axes(x) == (1:3,) while first had (1:2,) +[...] + +julia> Iterators.flatmap(n -> (-n, 10n), 1:2) |> collect +4-element Vector{Int64}: + -1 + 10 + -2 + 20 + +julia> ans == vec(stack(n -> (-n, 10n), 1:2)) +true ``` """ flatmap(f, c...) = flatten(map(f, c...)) diff --git a/doc/src/base/arrays.md b/doc/src/base/arrays.md index 853e4c7a4ec1b..6585f98360585 100644 --- a/doc/src/base/arrays.md +++ b/doc/src/base/arrays.md @@ -145,6 +145,7 @@ Base.vcat Base.hcat Base.hvcat Base.hvncat +Base.stack Base.vect Base.circshift Base.circshift! diff --git a/test/abstractarray.jl b/test/abstractarray.jl index 1ca91cad77d61..5e4612314e8d4 100644 --- a/test/abstractarray.jl +++ b/test/abstractarray.jl @@ -1553,6 +1553,132 @@ using Base: typed_hvncat @test [["A";"B"];;"C";"D"] == ["A" "C"; "B" "D"] end +@testset "stack" begin + # Basics + for args in ([[1, 2]], [1:2, 3:4], [[1 2; 3 4], [5 6; 7 8]], + AbstractVector[1:2, [3.5, 4.5]], Vector[[1,2], [3im, 4im]], + [[1:2, 3:4], [5:6, 7:8]], [fill(1), fill(2)]) + X = stack(args) + Y = cat(args...; dims=ndims(args[1])+1) + @test X == Y + @test typeof(X) === typeof(Y) + + X2 = stack(x for x in args) + @test X2 == Y + @test typeof(X2) === typeof(Y) + + X3 = stack(x for x in args if true) + @test X3 == Y + @test typeof(X3) === typeof(Y) + + if isconcretetype(eltype(args)) + @inferred stack(args) + @inferred stack(x for x in args) + end + end + + # Higher dims + @test size(stack([rand(2,3) for _ in 1:4, _ in 1:5])) == (2,3,4,5) + @test size(stack(rand(2,3) for _ in 1:4, _ in 1:5)) == (2,3,4,5) + @test size(stack(rand(2,3) for _ in 1:4, _ in 1:5 if true)) == (2, 3, 20) + @test size(stack([rand(2,3) for _ in 1:4, _ in 1:5]; dims=1)) == (20, 2, 3) + @test size(stack(rand(2,3) for _ in 1:4, _ in 1:5; dims=2)) == (2, 20, 3) + + # Tuples + @test stack([(1,2), (3,4)]) == [1 3; 2 4] + @test stack(((1,2), (3,4))) == [1 3; 2 4] + @test stack(Any[(1,2), (3,4)]) == [1 3; 2 4] + @test stack([(1,2), (3,4)]; dims=1) == [1 2; 3 4] + @test stack(((1,2), (3,4)); dims=1) == [1 2; 3 4] + @test stack(Any[(1,2), (3,4)]; dims=1) == [1 2; 3 4] + @test size(@inferred stack(Iterators.product(1:3, 1:4))) == (2,3,4) + @test @inferred(stack([('a', 'b'), ('c', 'd')])) == ['a' 'c'; 'b' 'd'] + @test @inferred(stack([(1,2+3im), (4, 5+6im)])) isa Matrix{Number} + + # stack(f, iter) + @test @inferred(stack(x -> [x, 2x], 3:5)) == [3 4 5; 6 8 10] + @test @inferred(stack(x -> x*x'/2, [1:2, 3:4])) == [0.5 1.0; 1.0 2.0;;; 4.5 6.0; 6.0 8.0] + @test @inferred(stack(*, [1:2, 3:4], 5:6)) == [5 18; 10 24] + + # Iterators + @test stack([(a=1,b=2), (a=3,b=4)]) == [1 3; 2 4] + @test stack([(a=1,b=2), (c=3,d=4)]) == [1 3; 2 4] + @test stack([(a=1,b=2), (c=3,d=4)]; dims=1) == [1 2; 3 4] + @test stack([(a=1,b=2), (c=3,d=4)]; dims=2) == [1 3; 2 4] + @test stack((x/y for x in 1:3) for y in 4:5) == (1:3) ./ (4:5)' + @test stack((x/y for x in 1:3) for y in 4:5; dims=1) == (1:3)' ./ (4:5) + + # Exotic + ips = ((Iterators.product([i,i^2], [2i,3i,4i], 1:4)) for i in 1:5) + @test size(stack(ips)) == (2, 3, 4, 5) + @test stack(ips) == cat(collect.(ips)...; dims=4) + ips_cat2 = cat(reshape.(collect.(ips), Ref((2,1,3,4)))...; dims=2) + @test stack(ips; dims=2) == ips_cat2 + @test stack(collect.(ips); dims=2) == ips_cat2 + ips_cat3 = cat(reshape.(collect.(ips), Ref((2,3,1,4)))...; dims=3) + @test stack(ips; dims=3) == ips_cat3 # path for non-array accumulation on non-final dims + @test stack(collect, ips; dims=3) == ips_cat3 # ... and for array accumulation + @test stack(collect.(ips); dims=3) == ips_cat3 + + # Trivial, because numbers are iterable: + @test stack(abs2, 1:3) == [1, 4, 9] == collect(Iterators.flatten(abs2(x) for x in 1:3)) + + # Allocation tests + xv = [rand(10) for _ in 1:100] + xt = Tuple.(xv) + for dims in (1, 2, :) + @test stack(xv; dims) == stack(xt; dims) + @test_skip 9000 > @allocated stack(xv; dims) + @test_skip 9000 > @allocated stack(xt; dims) + end + xr = (reshape(1:1000,10,10,10) for _ = 1:1000) + for dims in (1, 2, 3, :) + stack(xr; dims) + @test_skip 8.1e6 > @allocated stack(xr; dims) + end + + # Mismatched sizes + @test_throws DimensionMismatch stack([1:2, 1:3]) + @test_throws DimensionMismatch stack([1:2, 1:3]; dims=1) + @test_throws DimensionMismatch stack([1:2, 1:3]; dims=2) + @test_throws DimensionMismatch stack([(1,2), (3,4,5)]) + @test_throws DimensionMismatch stack([(1,2), (3,4,5)]; dims=1) + @test_throws DimensionMismatch stack(x for x in [1:2, 1:3]) + @test_throws DimensionMismatch stack([[5 6; 7 8], [1, 2, 3, 4]]) + @test_throws DimensionMismatch stack([[5 6; 7 8], [1, 2, 3, 4]]; dims=1) + @test_throws DimensionMismatch stack(x for x in [[5 6; 7 8], [1, 2, 3, 4]]) + # Inner iterator of unknown length + @test_throws MethodError stack((x for x in 1:3 if true) for _ in 1:4) + @test_throws MethodError stack((x for x in 1:3 if true) for _ in 1:4; dims=1) + + @test_throws ArgumentError stack([1:3, 4:6]; dims=0) + @test_throws ArgumentError stack([1:3, 4:6]; dims=3) + @test_throws ArgumentError stack(abs2, 1:3; dims=2) + + # Empty + @test_throws ArgumentError stack(()) + @test_throws ArgumentError stack([]) + @test_throws ArgumentError stack(x for x in 1:3 if false) +end + +@testset "tests from PR 31644" begin + v_v_same = [rand(128) for ii in 1:100] + v_v_diff = Any[rand(128), rand(Float32,128), rand(Int, 128)] + v_v_diff_typed = Union{Vector{Float64},Vector{Float32},Vector{Int}}[rand(128), rand(Float32,128), rand(Int, 128)] + for v_v in (v_v_same, v_v_diff, v_v_diff_typed) + # Cover all combinations of iterator traits. + g_v = (x for x in v_v) + f_g_v = Iterators.filter(x->true, g_v) + f_v_v = Iterators.filter(x->true, v_v); + hcat_expected = hcat(v_v...) + vcat_expected = vcat(v_v...) + @testset "$(typeof(data))" for data in (v_v, g_v, f_g_v, f_v_v) + @test stack(data) == hcat_expected + @test vec(stack(data)) == vcat_expected + end + end +end + @testset "keepat!" begin a = [1:6;] @test a === keepat!(a, 1:5) diff --git a/test/offsetarray.jl b/test/offsetarray.jl index bf5beab5c3437..afd54ee576c16 100644 --- a/test/offsetarray.jl +++ b/test/offsetarray.jl @@ -810,6 +810,22 @@ end @test reshape(a, (:,)) === a end +@testset "stack" begin + nought = OffsetArray([0, 0.1, 0.01], 0:2) + ten = OffsetArray([1,10,100,1000], 10:13) + + @test stack(ten) == ten + @test stack(ten .+ nought') == ten .+ nought' + @test stack(x^2 for x in ten) == ten.^2 + + @test axes(stack(nought for _ in ten)) == (0:2, 10:13) + @test axes(stack([nought for _ in ten])) == (0:2, 10:13) + @test axes(stack(nought for _ in ten; dims=1)) == (10:13, 0:2) + @test axes(stack((x, x^2) for x in nought)) == (1:2, 0:2) + @test axes(stack(x -> x[end-1:end], ten for _ in nought, _ in nought)) == (1:2, 0:2, 0:2) + @test axes(stack([ten[end-1:end] for _ in nought, _ in nought])) == (1:2, 0:2, 0:2) +end + @testset "issue #41630: replace_ref_begin_end!/@view on offset-like arrays" begin x = OffsetArray([1 2; 3 4], -10:-9, 9:10) # 2×2 OffsetArray{...} with indices -10:-9×9:10