Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unzip #33324

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
189 changes: 189 additions & 0 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2467,3 +2467,192 @@ end

# Both methods forward to `_shrink`, passing the matching in-place function
# (`intersect!` / `setdiff!`), the first iterable, and the tuple of remaining iterables.
intersect(itr, itrs...) = _shrink(intersect!, itr, itrs)
setdiff( itr, itrs...) = _shrink(setdiff!, itr, itrs)

# Apply `call` to every element of the tuple `variables` and return the results as a
# tuple. The recursion peels off one element per step and stops at the empty tuple.
map_unrolled(call, variables::Tuple{}) = ()
function map_unrolled(call, variables)
    head = call(first(variables))
    return (head, map_unrolled(call, tail(variables))...)
end
bramtayl marked this conversation as resolved.
Show resolved Hide resolved

# Two-tuple form: apply `call` to matching elements of `variables1` and `variables2`
# and return the results as a tuple. Both tuples are consumed one element per step.
map_unrolled(call, variables1::Tuple{}, variables2::Tuple{}) = ()
function map_unrolled(call, variables1, variables2)
    head = call(first(variables1), first(variables2))
    return (head, map_unrolled(call, tail(variables1), tail(variables2))...)
end

# Like a tuple `map`, but every call also receives `fixed` as its first argument:
# the result is `(call(fixed, variables[1]), call(fixed, variables[2]), ...)`.
partial_map(call, fixed, variables::Tuple{}) = ()
function partial_map(call, fixed, variables)
    head = call(fixed, first(variables))
    return (head, partial_map(call, fixed, tail(variables))...)
end
bramtayl marked this conversation as resolved.
Show resolved Hide resolved

# Two-tuple form: call `call(fixed, a, b)` for matching elements `a` of `variables1`
# and `b` of `variables2`, returning the results as a tuple.
partial_map(call, fixed, variables1::Tuple{}, variables2::Tuple{}) = ()
function partial_map(call, fixed, variables1, variables2)
    head = call(fixed, first(variables1), first(variables2))
    return (head, partial_map(call, fixed, tail(variables1), tail(variables2))...)
end

# Left fold of `call` over the trailing arguments, written recursively. A single item
# is returned unchanged; otherwise the first two items are combined and the fold
# continues with that result in front of the remaining items.
reduce_unrolled(call, item) = item
function reduce_unrolled(call, item1, item2, rest...)
    folded = call(item1, item2)
    return reduce_unrolled(call, folded, rest...)
end

# An `AbstractArray` wrapper around a collection of columns.
#   * `ARow`    - element type reported by the array
#   * `N`       - number of dimensions reported by the array
#   * `Columns` - concrete type of the wrapped `columns` field
struct Rows{ARow, N, Columns} <: AbstractArray{ARow, N}
    columns::Columns
end

# Return `true` when `item` has exactly the axes `reference_axes`.
compare_axes(reference_axes, item) = axes(item) == reference_axes

# Return `true` when every argument has the same axes as the first one.
# No arguments, or a single argument, trivially agree.
same_axes() = true
same_axes(first_column, rest...) = _same_axes(axes(first_column), rest...)

# Recursive worker for `same_axes`. The base case returns `true`, so a single column
# (nothing left to compare) is valid instead of folding over an empty tuple.
# `&` rather than `&&` keeps every comparison evaluated.
_same_axes(reference_axes) = true
_same_axes(reference_axes, column, rest...) =
    compare_axes(reference_axes, column) & _same_axes(reference_axes, rest...)

# Construct a `Rows` with explicit element type `Row` and dimensionality `Dimension`,
# taking the `Columns` parameter from the argument. Inside the `@boundscheck` region it
# throws `DimensionMismatch` unless `same_axes` accepts the splatted columns.
@propagate_inbounds function Rows{Row, Dimension}(
    columns::Columns
) where {Row, Dimension, Columns}
    @boundscheck begin
        same_axes(columns...) ||
            throw(DimensionMismatch("All arguments to `Rows` must have the same axes"))
    end
    return Rows{Row, Dimension, Columns}(columns)
end

# Pick the column that stands in for the whole collection: the first column, or the
# empty range `1:0` when there are no columns at all.
function get_model(::Tuple{})
    return 1:0
end
function get_model(columns)
    return first(columns)
end

# Convenience constructor: the row type is the tuple of the columns' element types and
# the dimensionality is that of the model column (see `get_model`).
@propagate_inbounds function Rows(columns)
    row_type = Tuple{map_unrolled(eltype, columns)...}
    dimensions = ndims(get_model(columns))
    return Rows{row_type, dimensions}(columns)
end

# The model column (first column, or `1:0` without columns) acts as the parent array.
function parent(rows::Rows)
    return get_model(rows.columns)
end

# Accessor for the wrapped columns.
function get_columns(rows::Rows)
    return rows.columns
end

# Shape queries are answered by the parent column.
function axes(rows::Rows, dimensions...)
    return axes(parent(rows), dimensions...)
end
function size(rows::Rows, dimensions...)
    return size(parent(rows), dimensions...)
end

# Index one column with the index tuple `an_index`.
@propagate_inbounds function column_getindex(an_index, column)
    return column[an_index...]
end

# Indexing a `Rows` indexes every column with the same indices and returns the
# results as a tuple.
@propagate_inbounds function getindex(rows::Rows, an_index...)
    return partial_map(column_getindex, an_index, rows.columns)
end

# Store `value` into one column at the index tuple `an_index`; evaluates to `value`.
@propagate_inbounds function column_setindex!(an_index, column, value)
    column[an_index...] = value
    return value
end

# Assigning a row stores each element of `row` into the matching column at the same
# indices. The result is the tuple produced by `partial_map`.
@propagate_inbounds setindex!(rows::Rows, row, an_index...) =
    partial_map(column_setindex!, an_index, rows.columns, row)

# Append `row` by pushing each of its elements onto the matching column.
# Returns `rows` itself, following the `push!(collection, items...) -> collection`
# convention, rather than the tuple of columns.
function push!(rows::Rows, row)
    map_unrolled(push!, rows.columns, row)
    return rows
end

# Lift the field types of a type into a tuple of `Val`s.
# Anything that is not a `DataType`, an abstract type, or a vararg tuple type yields
# `()`; otherwise each entry of `a_type.types` is wrapped with `Val`.
val_fieldtypes(something) = ()
@pure function val_fieldtypes(a_type::DataType)
    if a_type.abstract || (a_type.name == Tuple.name && isvatuple(a_type))
        return ()
    end
    return map_unrolled(Val, (a_type.types...,))
end

# Build one column: `similar` to `model`, with element type `Value` and the requested
# dimensions. The model and dimensions arrive bundled as one tuple so the function fits
# the `partial_map` calling shape.
function similar_val(model_and_dimensions, ::Val{Value}) where {Value}
    model, dimensions = model_and_dimensions
    return similar(model, Value, dimensions)
end

# A `Rows` similar to `rows`, with one column per field type of `ARow`, each column
# derived from the parent column and sized by `dimensions`.
function similar(rows::Rows, ::Type{ARow}, dimensions::Dims) where {ARow}
    return @inbounds Rows(partial_map(
        similar_val,
        (parent(rows), dimensions),
        val_fieldtypes(ARow),
    ))
end

# Create a zero-length `Rows` like `column`, optionally with row type `NewRow`.
# The length is requested explicitly: `similar(column, NewRow)` without dimensions
# would reuse the axes of `column`, so the result would not be empty.
empty(column::Rows{OldRow}, ::Type{NewRow} = OldRow) where {OldRow, NewRow} =
    similar(column, NewRow, (0,))

# `HasLength` case, item fits the column's element type: store it in place and return
# the same column.
function widen_column(
    ::HasLength, new_length, an_index, column::Array{Element}, item::Item
) where {Element, Item <: Element}
    @inbounds column[an_index] = item
    return column
end

# `HasLength` case, item does not fit: delegate to `setindex_widen_up_to` to obtain a
# widened column. This method takes the same five arguments as its siblings so the
# splatting entry point `widen_column(fixeds, variables)` can reach it.
widen_column(::HasLength, new_length, an_index, column::Array, item) =
    setindex_widen_up_to(column, item, an_index)

# `SizeUnknown` case, item fits the column's element type: grow the column in place
# and return that same column.
function widen_column(
    ::SizeUnknown, new_length, an_index, column::Array{Element}, item::Item
) where {Element, Item <: Element}
    push!(column, item)
    return column
end

# `SizeUnknown` case, item does not fit: delegate to `push_widen`.
function widen_column(::SizeUnknown, new_length, an_index, column::Array, item)
    return push_widen(column, item)
end

# No column exists yet at this position (`missing` placeholder) but the row carries an
# item: create an all-`missing` column of `new_length` that can also hold `Item`, then
# store the item at `an_index`.
function widen_column(
    iterator_size, new_length, an_index, ::Missing, item::Item
) where {Item}
    fresh = Array{Union{Missing, Item}}(missing, new_length)
    @inbounds fresh[an_index] = item
    return fresh
end

# Neither a column nor an item: an all-`missing` column of `new_length`.
function widen_column(iterator_size, new_length, an_index, ::Missing, ::Missing)
    return Array{Missing}(missing, new_length)
end

# Entry point matching the `partial_map` calling shape: `fixeds` and `variables` are
# tuples that get splatted into the positional methods.
function widen_column(fixeds, variables)
    return widen_column(fixeds..., variables...)
end

# Length to give a newly created column: the index being written when the iterator
# size is unknown, and `length(rows)` when the iterator has a length.
function get_new_length(::SizeUnknown, rows, an_index)
    return an_index
end
function get_new_length(::HasLength, rows, an_index)
    return length(rows)
end

# Pair every argument with `missing` in the first slot: `(missing, item)`.
zip_first_missing() = ()
function zip_first_missing(head, the_tail...)
    pair = (missing, head)
    return (pair, zip_first_missing(the_tail...)...)
end

# Pair every argument with `missing` in the second slot: `(item, missing)`.
zip_second_missing() = ()
function zip_second_missing(head, the_tail...)
    pair = (head, missing)
    return (pair, zip_second_missing(the_tail...)...)
end

# Zip two tuples into a tuple of pairs. When one tuple runs out first, the remaining
# elements of the other are paired with `missing` on the exhausted side.
zip_missing(::Tuple{}, ::Tuple{}) = ()
zip_missing(::Tuple{}, longer) = zip_first_missing(longer...)
zip_missing(longer, ::Tuple{}) = zip_second_missing(longer...)
function zip_missing(tuple1, tuple2)
    pair = (first(tuple1), first(tuple2))
    return (pair, zip_missing(tail(tuple1), tail(tuple2))...)
end

# Produce a `Rows` able to hold `row` at `an_index` (by default one past the end).
# Existing columns and the row's items are paired up with `zip_missing`, so a position
# present on only one side is paired with `missing`. Each pair is handed to
# `widen_column` together with the fixed arguments
# `(iterator_size, new_length, an_index)`.
function widen_columns(iterator_size, rows, row, an_index = length(rows) + 1)
    new_length = get_new_length(iterator_size, rows, an_index)
    fixeds = (iterator_size, new_length, an_index)
    return @inbounds Rows(partial_map(
        widen_column,
        fixeds,
        zip_missing(rows.columns, row),
    ))
end

# Widening hooks for `Rows`: both delegate to `widen_columns`, tagging the call with
# the iterator-size trait that selects the per-column strategy.
function push_widen(rows::Rows, row)
    return widen_columns(SizeUnknown(), rows, row)
end
function setindex_widen_up_to(rows::Rows, row, an_index)
    return widen_columns(HasLength(), rows, row, an_index)
end

"""
    unzip(rows)

Eagerly collect an iterator of rows into a tuple of columns, one column per position in
the rows. When rows differ in length, positions absent from a row are filled with
`missing`. An iterator that yields no rows gives an empty tuple.

```jldoctest
julia> using Base: Generator

julia> using Test: @inferred

julia> stable(x) = (x, x + 0.0, x, x + 0.0, x, x + 0.0);

julia> @inferred unzip(Generator(stable, 1:4))
([1, 2, 3, 4], [1.0, 2.0, 3.0, 4.0], [1, 2, 3, 4], [1.0, 2.0, 3.0, 4.0], [1, 2, 3, 4], [1.0, 2.0, 3.0, 4.0])

julia> unstable(x) =
           if x == 2
               (x, x + 0.0, x, x + 0.0)
           else
               (x, x + 0.0)
           end;

julia> unzip(Generator(unstable, 1:3))
([1, 2, 3], [1.0, 2.0, 3.0], Union{Missing, Int64}[missing, 2, missing], Union{Missing, Float64}[missing, 2.0, missing])

julia> unzip(Iterators.filter(row -> true, Generator(unstable, 1:3)))
([1, 2, 3], [1.0, 2.0, 3.0], Union{Missing, Int64}[missing, 2, missing], Union{Missing, Float64}[missing, 2.0, missing])

julia> unzip(Iterators.filter(row -> false, Generator(unstable, 1:4)))
()
```
"""
function unzip(rows)
    # Start from a `Rows` with no columns; collection widens it as rows arrive.
    no_columns = @inbounds Rows(())
    collected = _collect(no_columns, rows, IteratorEltype(rows), IteratorSize(rows))
    return get_columns(collected)
end
1 change: 1 addition & 0 deletions base/exports.jl
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,7 @@ export
⊋,
∩,
∪,
unzip,

# strings
ascii,
Expand Down
40 changes: 40 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -995,3 +995,43 @@ end
@test getindex(x) == getindex(x, CartesianIndex()) == 10
end
end

@testset "Unzip" begin
    # Rows of uniform shape: the result must be inferrable and concretely typed.
    stable(x) = (x, x + 0.0, x, x + 0.0, x, x + 0.0)
    stable_result = @inferred unzip(Base.Generator(stable, 1:2))
    @test stable_result == (1:2, 1:2, 1:2, 1:2, 1:2, 1:2)
    @test typeof(stable_result) ==
        Tuple{
            Vector{Int}, Vector{Float64}, Vector{Int}, Vector{Float64},
            Vector{Int}, Vector{Float64}
        }

    # Rows of varying length: the extra columns are padded with `missing`.
    function unstable(x)
        if x == 2
            (x, x + 0.0, x, x + 0.0)
        else
            (x, x + 0.0)
        end
    end
    padded_columns = (1:3, 1:3, [missing, 2, missing], [missing, 2, missing])
    padded_type = Tuple{
        Vector{Int}, Vector{Float64}, Vector{Union{Missing, Int}},
        Vector{Union{Missing, Float64}}
    }

    # Source iterator with a known length.
    unstable_result_1 = unzip(Base.Generator(unstable, 1:3))
    @test isequal(unstable_result_1, padded_columns)
    @test typeof(unstable_result_1) == padded_type

    # Source iterator of unknown size (filtered).
    unstable_result_2 =
        unzip(Iterators.filter(row -> true, Base.Generator(unstable, 1:3)))
    @test isequal(unstable_result_2, padded_columns)
    @test typeof(unstable_result_2) == padded_type

    # No rows at all.
    unstable_result_3 =
        unzip(Iterators.filter(row -> false, Base.Generator(unstable, 1:4)))
    @test unstable_result_3 == ()

    # Columns with different axes are rejected.
    @test_throws DimensionMismatch Base.Rows((1:2, 1:3))
end