Skip to content

Commit

Permalink
add length asserts to reverse(::Zip)
Browse files Browse the repository at this point in the history
  • Loading branch information
adienes committed Jul 6, 2023
1 parent c09a199 commit eef0324
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 1 deletion.
17 changes: 16 additions & 1 deletion base/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,16 @@ function _zip_min_length(is)
end
end
_zip_min_length(is::Tuple{}) = nothing
function _zip_lengths_finite_equal(is)
i = is[1]
b, n = _zip_lengths_finite_equal(tail(is))
if IteratorSize(i) isa Union{IsInfinite, SizeUnknown}
return (false, nothing)
else
return (b && (n === nothing || n == length(i)), length(i))
end
end
_zip_lengths_finite_equal(is::Tuple{}) = (true, nothing)
size(z::Zip) = _promote_tuple_shape(Base.map(size, z.is)...)
axes(z::Zip) = _promote_tuple_shape(Base.map(axes, z.is)...)
_promote_tuple_shape((a,)::Tuple{OneTo}, (b,)::Tuple{OneTo}) = (intersect(a, b),)
Expand Down Expand Up @@ -468,8 +478,13 @@ zip_iteratoreltype() = HasEltype()
zip_iteratoreltype(a) = a
zip_iteratoreltype(a, tail...) = and_iteratoreltype(a, zip_iteratoreltype(tail...))

reverse(z::Zip) = Zip(Base.map(reverse, z.is)) # n.b. we assume all iterators are the same length
last(z::Zip) = getindex.(z.is, minimum(Base.map(lastindex, z.is)))
function reverse(z::Zip)
if !first(_zip_lengths_finite_equal(z.is))
throw(ArgumentError("Zipped iterators of unknown, infinite, or unequal lengths must be collected before reversing"))
end
Zip(Base.map(reverse, z.is))
end

# filter

Expand Down
4 changes: 4 additions & 0 deletions test/iterators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,10 @@ using Dates: Date, Day
# issue #4718
@test collect(Iterators.filter(x->x[1], zip([true, false, true, false],"abcd"))) == [(true,'a'),(true,'c')]

# issue #45085
@test_throws ArgumentError Iterators.reverse(zip("abc", "abcd"))
@test_throws ArgumentError Iterators.reverse(zip("abc", Iterators.cycle("ab")))

let z = zip(1:2)
@test size(z) == (2,)
@test collect(z) == [(1,), (2,)]
Expand Down

0 comments on commit eef0324

Please sign in to comment.