From 71291dfd7ac5cc881bddb6a3bb0ce331bb4d11d7 Mon Sep 17 00:00:00 2001 From: Jakob Nybo Nissen Date: Wed, 25 Aug 2021 00:42:41 +0200 Subject: [PATCH] Fix collect on stateful generator (#41919) Previously this code would drop 1 from the length of some generators. Fixes #35530 (cherry picked from commit 8364a4ccd8885fa8d8c78094c7653c58e33d9f0d) --- base/array.jl | 10 +++++++--- test/iterators.jl | 9 +++++++++ 2 files changed, 16 insertions(+), 3 deletions(-) diff --git a/base/array.jl b/base/array.jl index e7b87cf9739cef..cdc7ad7fedbbca 100644 --- a/base/array.jl +++ b/base/array.jl @@ -666,8 +666,10 @@ else end end -_array_for(::Type{T}, itr, ::HasLength) where {T} = Vector{T}(undef, Int(length(itr)::Integer)) -_array_for(::Type{T}, itr, ::HasShape{N}) where {T,N} = similar(Array{T,N}, axes(itr)) +_array_for(::Type{T}, itr, isz::HasLength) where {T} = _array_for(T, itr, isz, length(itr)) +_array_for(::Type{T}, itr, isz::HasShape{N}) where {T,N} = _array_for(T, itr, isz, axes(itr)) +_array_for(::Type{T}, itr, ::HasLength, len) where {T} = Vector{T}(undef, len) +_array_for(::Type{T}, itr, ::HasShape{N}, axs) where {T,N} = similar(Array{T,N}, axs) function collect(itr::Generator) isz = IteratorSize(itr.iter) @@ -675,12 +677,14 @@ function collect(itr::Generator) if isa(isz, SizeUnknown) return grow_to!(Vector{et}(), itr) else + shape = isz isa HasLength ? length(itr) : axes(itr) y = iterate(itr) if y === nothing return _array_for(et, itr.iter, isz) end v1, st = y - collect_to_with_first!(_array_for(typeof(v1), itr.iter, isz), v1, itr, st) + arr = _array_for(typeof(v1), itr.iter, isz, shape) + return collect_to_with_first!(arr, v1, itr, st) end end diff --git a/test/iterators.jl b/test/iterators.jl index b9bec84bf9a581..6007a59464f5a8 100644 --- a/test/iterators.jl +++ b/test/iterators.jl @@ -291,6 +291,15 @@ let (a, b) = (1:3, [4 6; end end +# collect stateful iterator +let + itr = (i+1 for i in Base.Stateful([1,2,3])) + @test collect(itr) == [2, 3, 4] + A = zeros(Int, 0, 0) + itr = (i-1 for i in Base.Stateful(A)) + @test collect(itr) == Int[] # Stateful do not preserve shape +end + # with 1D inputs let a = 1:2, b = 1.0:10.0,