diff --git a/base/array.jl b/base/array.jl
index 88c235fd559dc..d14ce9ed85942 100644
--- a/base/array.jl
+++ b/base/array.jl
@@ -2467,3 +2467,194 @@ end
 
 intersect(itr, itrs...) = _shrink(intersect!, itr, itrs)
 setdiff(  itr, itrs...) = _shrink(setdiff!, itr, itrs)
+
+# Array of rows backed by a tuple of column arrays: indexing a `Rows` gathers
+# one item from every column into a tuple.
+struct Rows{Row, Dimensions, Columns} <: AbstractArray{Row, Dimensions}
+    columns::Columns
+end
+
+function compare_axes(reference_axes, item)
+    axes(item) == reference_axes
+end
+# `true` when every column has the same axes as the first one. The reduction is
+# seeded with `true` so that a single column (empty `rest`) is accepted as well.
+function same_axes(first_column, rest...)
+    reduce_unrolled(&, true, partial_map(compare_axes, axes(first_column), rest)...)
+end
+function same_axes()
+    true
+end
+
+# Build `Rows` with an explicit row type and dimensionality. The axes check is a
+# `@boundscheck`, so callers that already know the columns agree can elide it
+# with `@inbounds`.
+@propagate_inbounds function Rows{Row, Dimension}(columns::Columns) where {Row, Dimension, Columns}
+    @boundscheck if !same_axes(columns...)
+        throw(DimensionMismatch("All arguments to `Rows` must have the same axes"))
+    end
+    Rows{Row, Dimension, Columns}(columns)
+end
+
+# The column that supplies `axes` and `size`; an empty range stands in when there
+# are no columns.
+dummy_column(columns) = first(columns)
+dummy_column(::Tuple{}) = 1:0
+
+@propagate_inbounds Rows(columns) = Rows{
+    Tuple{map_unrolled(eltype, columns)...},
+    ndims(dummy_column(columns))
+}(columns)
+
+parent(rows::Rows) = dummy_column(rows.columns)
+
+get_columns(rows::Rows) = rows.columns
+
+axes(rows::Rows, dimensions...) = axes(parent(rows), dimensions...)
+size(rows::Rows, dimensions...) = size(parent(rows), dimensions...)
+
+# Read one row: index every column with the same index and return the tuple.
+@propagate_inbounds column_getindex(an_index, column) = column[an_index...]
+@propagate_inbounds getindex(rows::Rows, an_index...) = partial_map(
+    column_getindex,
+    an_index,
+    rows.columns
+)
+
+# Write one row: store each item of `row` into the matching column.
+@propagate_inbounds function column_setindex!(an_index, column, value)
+    column[an_index...] = value
+end
+@propagate_inbounds function setindex!(rows::Rows, row, an_index...)
+    partial_map(column_setindex!, an_index, rows.columns, row)
+end
+
+# Append each item of `row` to the matching column; returns `rows`, following the
+# `push!` convention of returning the collection.
+function push!(rows::Rows, row)
+    map_unrolled(push!, rows.columns, row)
+    rows
+end
+
+# Field types of `a_type`, each wrapped in `Val`; abstract types, vararg tuples
+# and anything that is not a `DataType` give `()`.
+val_fieldtypes(something) = ()
+@pure val_fieldtypes(a_type::DataType) =
+    if a_type.abstract || (a_type.name == Tuple.name && isvatuple(a_type))
+        ()
+    else
+        map_unrolled(Val, (a_type.types...,))
+    end
+
+# Allocate one `similar` column per field type of `ARow`.
+similar_val((model, dimensions), ::Val{Value}) where {Value} =
+    similar(model, Value, dimensions)
+similar(rows::Rows, ::Type{ARow}, dimensions::Dims) where {ARow} =
+    @inbounds Rows(partial_map(
+        similar_val,
+        (parent(rows), dimensions),
+        val_fieldtypes(ARow)
+    ))
+
+empty(column::Rows{OldRow}, ::Type{NewRow} = OldRow) where {OldRow, NewRow} =
+    similar(column, NewRow)
+
+# widen_column(iterator_size, new_length, an_index, column, item)
+# Store `item` at `an_index` of `column` and return the column to keep using: the
+# same array when `item` fits its element type, a widened array otherwise, or a
+# new `missing`-filled array when there was no column yet (`column` is `missing`).
+function widen_column(::HasLength, new_length, an_index, column::Array{Element}, item::Item) where {Element, Item <: Element}
+    @inbounds column[an_index] = item
+    column
+end
+widen_column(::HasLength, new_length, an_index, column::Array, item) =
+    setindex_widen_up_to(column, item, an_index)
+
+function widen_column(::SizeUnknown, new_length, an_index, column::Array{Element}, item::Item) where {Element, Item <: Element}
+    push!(column, item)
+    column
+end
+widen_column(::SizeUnknown, new_length, an_index, column::Array, item) =
+    push_widen(column, item)
+
+function widen_column(iterator_size, new_length, an_index, ::Missing, item::Item) where {Item}
+    new_column = Array{Union{Missing, Item}}(missing, new_length)
+    @inbounds new_column[an_index] = item
+    new_column
+end
+widen_column(iterator_size, new_length, an_index, ::Missing, ::Missing) =
+    Array{Missing}(missing, new_length)
+
+# Adapter for `partial_map`, which passes the fixed and the varying arguments as
+# two tuples.
+widen_column(fixeds, variables) = widen_column(fixeds..., variables...)
+
+get_new_length(::SizeUnknown, rows, an_index) = an_index
+get_new_length(::HasLength, rows, an_index) = length(rows)
+
+# Pair up two tuples item by item, padding the shorter one with `missing`.
+zip_first_missing() = ()
+zip_first_missing(head, the_tail...) = (missing, head), zip_first_missing(the_tail...)...
+zip_second_missing() = ()
+zip_second_missing(head, the_tail...) = (head, missing), zip_second_missing(the_tail...)...
+zip_missing(::Tuple{}, ::Tuple{}) = ()
+zip_missing(::Tuple{}, longer) = zip_first_missing(longer...)
+zip_missing(longer, ::Tuple{}) = zip_second_missing(longer...)
+zip_missing(tuple1, tuple2) =
+    (first(tuple1), first(tuple2)),
+    zip_missing(tail(tuple1), tail(tuple2))...
+
+# Rebuild `rows` so that `row` fits at `an_index`, widening or adding columns as
+# needed.
+function widen_columns(iterator_size, rows, row, an_index = length(rows) + 1)
+    columns = rows.columns
+    @inbounds Rows(partial_map(
+        widen_column,
+        (iterator_size, get_new_length(iterator_size, rows, an_index), an_index),
+        zip_missing(columns, row)
+    ))
+end
+
+push_widen(rows::Rows, row) = widen_columns(SizeUnknown(), rows, row)
+setindex_widen_up_to(rows::Rows, row, an_index) =
+    widen_columns(HasLength(), rows, row, an_index)
+
+"""
+    unzip(rows)
+
+Collect an iterator of tuples into a tuple of columns. Always eager. Rows that
+are shorter than the longest row are padded with `missing`.
+
+```jldoctest
+julia> using Base: Generator
+
+julia> using Test: @inferred
+
+julia> stable(x) = (x, x + 0.0, x, x + 0.0, x, x + 0.0);
+
+julia> @inferred unzip(Generator(stable, 1:4))
+([1, 2, 3, 4], [1.0, 2.0, 3.0, 4.0], [1, 2, 3, 4], [1.0, 2.0, 3.0, 4.0], [1, 2, 3, 4], [1.0, 2.0, 3.0, 4.0])
+
+julia> unstable(x) =
+           if x == 2
+               (x, x + 0.0, x, x + 0.0)
+           else
+               (x, x + 0.0)
+           end;
+
+julia> unzip(Generator(unstable, 1:3))
+([1, 2, 3], [1.0, 2.0, 3.0], Union{Missing, Int64}[missing, 2, missing], Union{Missing, Float64}[missing, 2.0, missing])
+
+julia> unzip(Iterators.filter(row -> true, Generator(unstable, 1:3)))
+([1, 2, 3], [1.0, 2.0, 3.0], Union{Missing, Int64}[missing, 2, missing], Union{Missing, Float64}[missing, 2.0, missing])
+
+julia> unzip(Iterators.filter(row -> false, Generator(unstable, 1:4)))
+()
+```
+"""
+unzip(rows) = get_columns(_collect(
+    (@inbounds Rows(())),
+    rows,
+    IteratorEltype(rows),
+    IteratorSize(rows)
+))
diff --git a/base/exports.jl b/base/exports.jl
index 49063720c14c4..548e47bc2e0cf 100644
--- a/base/exports.jl
+++ b/base/exports.jl
@@ -542,6 +542,7 @@ export
     ⊋,
     ∩,
     ∪,
+    unzip,
 
 # strings
     ascii,
diff --git a/base/tuple.jl b/base/tuple.jl
index 7e8c6804bc496..6f711df823d08 100644
--- a/base/tuple.jl
+++ b/base/tuple.jl
@@ -409,3 +409,37 @@ _tuple_any(f::Function, tf::Bool) = tf
 Returns an empty tuple, `()`.
 """
 empty(@nospecialize x::Tuple) = ()
+
+# Unrolled operations
+# Unexported helpers for tuple operations that are unrolled via recursion.
+
+# map_unrolled(call, variables): apply `call` to each item of a tuple.
+map_unrolled(call, variables::Tuple{}) = ()
+map_unrolled(call, variables) =
+    call(first(variables)), map_unrolled(call, tail(variables))...
+
+# map_unrolled(call, variables1, variables2): apply `call` to paired items of two
+# tuples; the tuples are expected to have equal length.
+map_unrolled(call, variables1::Tuple{}, variables2::Tuple{}) = ()
+map_unrolled(call, variables1, variables2) =
+    call(first(variables1), first(variables2)),
+    map_unrolled(call, tail(variables1), tail(variables2))...
+
+# partial_map(call, fixed, variables...): like `map_unrolled`, but `fixed` is
+# passed as the first argument of every call.
+partial_map(call, fixed, variables::Tuple{}) = ()
+partial_map(call, fixed, variables) =
+    call(fixed, first(variables)), partial_map(call, fixed, tail(variables))...
+
+partial_map(call, fixed, variables1::Tuple{}, variables2::Tuple{}) = ()
+partial_map(call, fixed, variables1, variables2) =
+    call(fixed, first(variables1), first(variables2)),
+    partial_map(call, fixed, tail(variables1), tail(variables2))...
+
+# reduce_unrolled(call, items...): left fold of `call` over one or more items.
+function reduce_unrolled(call, item)
+    item
+end
+function reduce_unrolled(call, item1, item2, rest...)
+    reduce_unrolled(call, call(item1, item2), rest...)
+end
diff --git a/test/abstractarray.jl b/test/abstractarray.jl
index bf6c0f78a30ab..969e5b46d51d7 100644
--- a/test/abstractarray.jl
+++ b/test/abstractarray.jl
@@ -995,3 +995,49 @@ end
         @test getindex(x) == getindex(x, CartesianIndex()) == 10
     end
 end
+
+@testset "Unzip" begin
+    stable(x) = (x, x + 0.0, x, x + 0.0, x, x + 0.0);
+    stable_result = @inferred unzip(Base.Generator(stable, 1:2))
+    @test stable_result == (1:2, 1:2, 1:2, 1:2, 1:2, 1:2)
+    @test typeof(stable_result) ==
+        Tuple{
+            Vector{Int}, Vector{Float64}, Vector{Int}, Vector{Float64},
+            Vector{Int}, Vector{Float64}
+        }
+
+    unstable(x) =
+        if x == 2
+            (x, x + 0.0, x, x + 0.0)
+        else
+            (x, x + 0.0)
+        end
+
+    unstable_result_1 = unzip(Base.Generator(unstable, 1:3))
+    @test isequal(unstable_result_1, (1:3, 1:3, [missing, 2, missing], [missing, 2, missing]))
+    @test typeof(unstable_result_1) ==
+        Tuple{
+            Vector{Int}, Vector{Float64}, Vector{Union{Missing, Int}},
+            Vector{Union{Missing, Float64}}
+        }
+
+    unstable_result_2 = unzip(Iterators.filter(row -> true, Base.Generator(unstable, 1:3)))
+    @test isequal(unstable_result_2, (1:3, 1:3, [missing, 2, missing], [missing, 2, missing]))
+    @test typeof(unstable_result_2) ==
+        Tuple{
+            Vector{Int}, Vector{Float64}, Vector{Union{Missing, Int}},
+            Vector{Union{Missing, Float64}}
+        }
+
+    unstable_result_3 = unzip(Iterators.filter(row -> false, Base.Generator(unstable, 1:4)))
+    @test unstable_result_3 == ()
+
+    @test_throws DimensionMismatch Base.Rows((1:2, 1:3))
+
+    # a single column trivially has consistent axes
+    @test size(Base.Rows((1:2,))) == (2,)
+
+    # widening the element type of an existing column when the length is known
+    widened(x) = x == 2 ? (x + 0.0,) : (x,)
+    @test unzip(Base.Generator(widened, 1:3)) == ([1, 2.0, 3],)
+end
diff --git a/test/tuple.jl b/test/tuple.jl
index 26d5d5f58d953..cd2059f4d45d6 100644
--- a/test/tuple.jl
+++ b/test/tuple.jl
@@ -451,3 +451,11 @@ end
 
 # tuple_type_tail on non-normalized vararg tuple
 @test Base.tuple_type_tail(Tuple{Vararg{T, 3}} where T<:Real) == Tuple{Vararg{T, 2}} where T<:Real
+
+@testset "unrolled" begin
+    @test (@inferred Base.map_unrolled(+, (1, 1.0))) == (1, 1.0)
+    @test (@inferred Base.map_unrolled(+, (1, 1.0), (1, 1.0))) == (2, 2.0)
+    @test (@inferred Base.partial_map(+, 1, (1, 1.0))) == (2, 2.0)
+    @test (@inferred Base.partial_map(+, 1, (1, 1.0), (1, 1.0))) == (3, 3.0)
+    @test (@inferred Base.reduce_unrolled(+, 1, 1.0)) == 2.0
+end