Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unzip #33324

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 164 additions & 0 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2467,3 +2467,167 @@ end

intersect(itr, itrs...) = _shrink(intersect!, itr, itrs)
setdiff( itr, itrs...) = _shrink(setdiff!, itr, itrs)

struct Rows{Row, Dimensions, Columns} <: AbstractArray{Row, Dimensions}
columns::Columns
end

function compare_axes(reference_axes, item)
axes(item) == reference_axes
end
function same_axes(first_column, rest...)
reduce_unrolled(&, partial_map(compare_axes, axes(first_column), rest)...)
end
function same_axes()
true
end

@propagate_inbounds function Rows{Row, Dimension}(columns::Columns) where {Row, Dimension, Columns}
@boundscheck if !same_axes(columns...)
throw(DimensionMismatch("All arguments to `Rows` must have the same axes"))
end
Rows{Row, Dimension, Columns}(columns)
end

dummy_column(columns) = first(columns)
dummy_column(::Tuple{}) = 1:0

@propagate_inbounds Rows(columns) = Rows{
Tuple{map_unrolled(eltype, columns)...},
ndims(dummy_column(columns))
}(columns)

parent(rows::Rows) = dummy_column(rows.columns)

get_columns(rows::Rows) = rows.columns

axes(rows::Rows, dimensions...) = axes(parent(rows), dimensions...)
size(rows::Rows, dimensions...) = size(parent(rows), dimensions...)

@propagate_inbounds column_getindex(an_index, column) = column[an_index...]
@propagate_inbounds getindex(rows::Rows, an_index...) = partial_map(
column_getindex,
an_index,
rows.columns
)

@propagate_inbounds function column_setindex!(an_index, column, value)
column[an_index...] = value
end
@propagate_inbounds function setindex!(rows::Rows, row, an_index...)
partial_map(column_setindex!, an_index, rows.columns, row)
end

function push!(rows::Rows, row)
map_unrolled(push!, rows.columns, row)
end

val_fieldtypes(something) = ()
@pure val_fieldtypes(a_type::DataType) =
if a_type.abstract || (a_type.name == Tuple.name && isvatuple(a_type))
()
else
map_unrolled(Val, (a_type.types...,))
end

similar_val((model, dimensions), ::Val{Value}) where {Value} =
similar(model, Value, dimensions)
similar(rows::Rows, ::Type{ARow}, dimensions::Dims) where {ARow} =
@inbounds Rows(partial_map(
similar_val,
(parent(rows), dimensions),
val_fieldtypes(ARow)
))

empty(column::Rows{OldRow}, ::Type{NewRow} = OldRow) where {OldRow, NewRow} =
similar(column, NewRow)

function widen_column(::HasLength, new_length, an_index, column::Array{Element}, item::Item) where {Element, Item <: Element}
@inbounds column[an_index] = item
column
end
widen_column(::HasLength, new_length, an_index, name, column::Array, item) =
setindex_widen_up_to(column, item, an_index)

function widen_column(::SizeUnknown, new_length, an_index, column::Array{Element}, item::Item) where {Element, Item <: Element}
push!(column, item)
column
end
widen_column(::SizeUnknown, new_length, an_index, column::Array, item) =
push_widen(column, item)

function widen_column(iterator_size, new_length, an_index, ::Missing, item::Item) where {Item}
new_column = Array{Union{Missing, Item}}(missing, new_length)
@inbounds new_column[an_index] = item
new_column
end
widen_column(iterator_size, new_length, an_index, ::Missing, ::Missing) =
Array{Missing}(missing, new_length)

widen_column(fixeds, variables) = widen_column(fixeds..., variables...)

get_new_length(::SizeUnknown, rows, an_index) = an_index
get_new_length(::HasLength, rows, an_index) = length(rows)

zip_first_missing() = ()
zip_first_missing(head, the_tail...) = (missing, head), zip_first_missing(the_tail...)...
zip_second_missing() = ()
zip_second_missing(head, the_tail...) = (head, missing), zip_second_missing(the_tail...)...
zip_missing(::Tuple{}, ::Tuple{}) = ()
zip_missing(::Tuple{}, longer) = zip_first_missing(longer...)
zip_missing(longer, ::Tuple{}) = zip_second_missing(longer...)
zip_missing(tuple1, tuple2) =
(first(tuple1), first(tuple2)),
zip_missing(tail(tuple1), tail(tuple2))...

function widen_columns(iterator_size, rows, row, an_index = length(rows) + 1)
columns = rows.columns
@inbounds Rows(partial_map(
widen_column,
(iterator_size, get_new_length(iterator_size, rows, an_index), an_index),
zip_missing(rows.columns, row)
))
end

push_widen(rows::Rows, row) = widen_columns(SizeUnknown(), rows, row)
setindex_widen_up_to(rows::Rows, row, an_index) =
widen_columns(HasLength(), rows, row, an_index)

"""
unzip(rows)

Collect into columns. Always eager, see [`to_columns`](@ref) for a lazy version.

```jldoctest
julia> using Base: Generator

julia> using Test: @inferred

julia> stable(x) = (x, x + 0.0, x, x + 0.0, x, x + 0.0);

julia> @inferred unzip(Generator(stable, 1:4))
([1, 2, 3, 4], [1.0, 2.0, 3.0, 4.0], [1, 2, 3, 4], [1.0, 2.0, 3.0, 4.0], [1, 2, 3, 4], [1.0, 2.0, 3.0, 4.0])

julia> unstable(x) =
if x == 2
(x, x + 0.0, x, x + 0.0)
else
(x, x + 0.0)
end;

julia> unzip(Generator(unstable, 1:3))
([1, 2, 3], [1.0, 2.0, 3.0], Union{Missing, Int64}[missing, 2, missing], Union{Missing, Float64}[missing, 2.0, missing])

julia> unzip(Iterators.filter(row -> true, Generator(unstable, 1:3)))
([1, 2, 3], [1.0, 2.0, 3.0], Union{Missing, Int64}[missing, 2, missing], Union{Missing, Float64}[missing, 2.0, missing])

julia> unzip(Iterators.filter(row -> false, Generator(unstable, 1:4)))
()
```
"""
unzip(rows) = get_columns(_collect(
(@inbounds Rows(())),
rows,
IteratorEltype(rows),
IteratorSize(rows)
))
1 change: 1 addition & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,7 @@ export
⊋,
∩,
∪,
unzip,

# strings
ascii,
Expand Down
27 changes: 27 additions & 0 deletions base/tuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -409,3 +409,30 @@ _tuple_any(f::Function, tf::Bool) = tf
Returns an empty tuple, `()`.
"""
empty(@nospecialize x::Tuple) = ()

# Unrolled operations
# A mini, unexported module for operations unrolled via recursion
map_unrolled(call, variables::Tuple{}) = ()
map_unrolled(call, variables) =
call(first(variables)), map_unrolled(call, tail(variables))...

map_unrolled(call, variables1::Tuple{}, variables2::Tuple{}) = ()
map_unrolled(call, variables1, variables2) =
call(first(variables1), first(variables2)),
map_unrolled(call, tail(variables1), tail(variables2))...

partial_map(call, fixed, variables::Tuple{}) = ()
partial_map(call, fixed, variables) =
call(fixed, first(variables)), partial_map(call, fixed, tail(variables))...

partial_map(call, fixed, variables1::Tuple{}, variables2::Tuple{}) = ()
partial_map(call, fixed, variables1, variables2) =
call(fixed, first(variables1), first(variables2)),
partial_map(call, fixed, tail(variables1), tail(variables2))...

function reduce_unrolled(call, item)
item
end
function reduce_unrolled(call, item1, item2, rest...)
reduce_unrolled(call, call(item1, item2), rest...)
end
40 changes: 40 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -995,3 +995,43 @@ end
@test getindex(x) == getindex(x, CartesianIndex()) == 10
end
end

@testset "Unzip" begin
stable(x) = (x, x + 0.0, x, x + 0.0, x, x + 0.0);
stable_result = @inferred unzip(Base.Generator(stable, 1:2))
@test stable_result == (1:2, 1:2, 1:2, 1:2, 1:2, 1:2)
@test typeof(stable_result) ==
Tuple{
Vector{Int}, Vector{Float64}, Vector{Int}, Vector{Float64},
Vector{Int}, Vector{Float64}
}

unstable(x) =
if x == 2
(x, x + 0.0, x, x + 0.0)
else
(x, x + 0.0)

end

unstable_result_1 = unzip(Base.Generator(unstable, 1:3))
@test isequal(unstable_result_1, (1:3, 1:3, [missing, 2, missing], [missing, 2, missing]))
@test typeof(unstable_result_1) ==
Tuple{
Vector{Int}, Vector{Float64}, Vector{Union{Missing, Int}},
Vector{Union{Missing, Float64}}
}

unstable_result_2 = unzip(Iterators.filter(row -> true, Base.Generator(unstable, 1:3)))
@test isequal(unstable_result_2, (1:3, 1:3, [missing, 2, missing], [missing, 2, missing]))
@test typeof(unstable_result_2) ==
Tuple{
Vector{Int}, Vector{Float64}, Vector{Union{Missing, Int}},
Vector{Union{Missing, Float64}}
}

unstable_result_3 = unzip(Iterators.filter(row -> false, Base.Generator(unstable, 1:4)))
@test unstable_result_3 == ()

@test_throws DimensionMismatch Base.Rows((1:2, 1:3))
end
8 changes: 8 additions & 0 deletions test/tuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -451,3 +451,11 @@ end

# tuple_type_tail on non-normalized vararg tuple
@test Base.tuple_type_tail(Tuple{Vararg{T, 3}} where T<:Real) == Tuple{Vararg{T, 2}} where T<:Real

@testset "unrolled" begin
@test @inferred Base.map_unrolled(+, (1, 1.0)) == (1, 1.0)
@test @inferred Base.map_unrolled(+, (1, 1.0), (1, 1.0)) == (2, 2.0)
@test @inferred Base.partial_map(+, 1, (1, 1.0)) == (2, 2.0)
@test @inferred Base.partial_map(+, 1, (1, 1.0), (1, 1.0)) == (3, 3.0)
@test @inferred Base.reduce_unrolled(+, 1, 1.0) == 2.0
end