RFC: Do not consider iterators as scalars in broadcast #25356

Closed
5 changes: 3 additions & 2 deletions base/array.jl
@@ -451,8 +451,9 @@ _similar_for(c, T, itr, isz) = similar(c, T)
collect(collection)

Return an `Array` of all items in a collection or iterator. For dictionaries, returns
`Pair{KeyType, ValType}`. If the argument is array-like or is an iterator with the `HasShape()`
trait, the result will have the same shape and number of dimensions as the argument.
`Pair{KeyType, ValType}`. If the argument is array-like or is an iterator with the
[`HasShape`](@ref iteratorsize) trait,
the result will have the same shape and number of dimensions as the argument.

# Examples
```jldoctest
2 changes: 1 addition & 1 deletion base/asyncmap.jl
@@ -125,7 +125,7 @@ function verify_ntasks(iterable, ntasks)

if ntasks == 0
chklen = iteratorsize(iterable)
if (chklen == HasLength()) || (chklen == HasShape())
if (chklen isa HasLength) || (chklen isa HasShape)
ntasks = max(1,min(100, length(iterable)))
else
ntasks = 100
21 changes: 19 additions & 2 deletions base/broadcast.jl
@@ -52,10 +52,27 @@ BroadcastStyle(::Type{Union{}}) = Unknown() # ambiguity resolution
"""
`Broadcast.Scalar()` is a [`BroadcastStyle`](@ref) indicating that an object is not
treated as a container for the purposes of broadcasting. This is the default for objects
that have not customized `BroadcastStyle`.
that have neither customized `BroadcastStyle` nor implemented the [`start`](@ref) method
(for iterator types).
"""
struct Scalar <: BroadcastStyle end
BroadcastStyle(::Type) = Scalar()
hasshape_ndims(::Base.HasShape{N}) where {N} = N
function BroadcastStyle(::Type{T}) where T
if method_exists(start, Tuple{T})
Member

This is not inferrable, and thus might be pretty bad for code like

function foo(x)
    y = Float64.(x)
    # now do something with y
end

Unfortunately #16422 wouldn't help.

At the same time, I recognize that any problematic type can be optimized by adding a specific definition.

At a minimum we may have to change the typeof calls in collect_styles to Core.Typeof, and then specialize

BroadcastStyle(::Type{Type{T}}) where T = Scalar()
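
The difference between `typeof` and `Core.Typeof` behind this suggestion can be sketched as follows. The names `DemoStyle`, `DemoScalar`, `DemoContainer`, `demo_style` and `style_of` are hypothetical stand-ins, not Base's broadcast machinery, and the generic fallback returning a container style is an assumption made only to make the contrast visible.

```julia
# typeof applied to a type returns DataType, which forgets which type it was.
@assert typeof(Int) === DataType
# Core.Typeof keeps that information in the result.
@assert Core.Typeof(Int) === Type{Int}
@assert Core.Typeof(1) === Int

abstract type DemoStyle end               # hypothetical stand-in for BroadcastStyle
struct DemoScalar <: DemoStyle end
struct DemoContainer <: DemoStyle end

demo_style(::Type) = DemoContainer()                    # assumed generic fallback
demo_style(::Type{Type{T}}) where {T} = DemoScalar()    # types themselves act as scalars

# Hypothetical helper: look up the style through Core.Typeof instead of typeof.
style_of(x) = demo_style(Core.Typeof(x))

@assert style_of(Int) === DemoScalar()
# Going through typeof misses the specialized method entirely.
@assert demo_style(typeof(Int)) === DemoContainer()
```

With `typeof`, every type argument collapses to `DataType`, so a method on `Type{Type{T}}` is never selected.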

Member Author

We could require iterators to define a trait. That would be more consistent with what we do elsewhere, and that wouldn't be a terrible burden either. I had contemplated adding a NotIterable type to Base.iteratorsize, but that could also be a separate function like isiterable.

Member

On balance I think it's better to require iterators to define a trait, and make scalars the default.

One slightly-crazy thought is that the trait name could be the output of BroadcastStyle. But on balance I'm not sure this is a good idea, because there may be reasons to have things that act like scalars that don't return Scalar(). It's probably better to have a separate isiterable trait.
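
A minimal sketch of what an opt-in trait with a scalar default could look like. Everything here is hypothetical: `DemoIterable`, `DemoIsIterable`, `DemoNotIterable`, `Countdown` and `describe` are illustrative names, not part of Base or of this PR, and `Countdown` uses the present-day `iterate` protocol rather than `start`/`next`/`done`.

```julia
abstract type DemoIterable end              # hypothetical trait, not in Base
struct DemoIsIterable <: DemoIterable end
struct DemoNotIterable <: DemoIterable end

# Scalars are the default: a type counts as iterable only if it opts in.
DemoIterable(::Type) = DemoNotIterable()
DemoIterable(x) = DemoIterable(typeof(x))

# A small iterator type that opts in explicitly.
struct Countdown
    from::Int
end
Base.iterate(c::Countdown, state=c.from) = state < 1 ? nothing : (state, state - 1)
Base.length(c::Countdown) = max(c.from, 0)
DemoIterable(::Type{Countdown}) = DemoIsIterable()

# Dispatch on the trait instance; no method-table query is involved,
# so the branch is decided from the argument type alone.
describe(x) = describe(DemoIterable(x), x)
describe(::DemoIsIterable, x) = "container"
describe(::DemoNotIterable, x) = "scalar"

@assert describe(Countdown(3)) == "container"
@assert describe(:sym) == "scalar"
```

The cost is the one discussed in this thread: every iterator type has to add the extra one-line method.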

Contributor

@vtjnash mentioned here that method_exists could be made inferrable now, presumably circumventing the problem mentioned in #16422.

Member

It could, but it hasn't been in the past since we don't want the coupling. Also, #25261 will break this.

S = Base.iteratorsize(T)
if S isa Base.HasLength
DefaultVectorStyle()
elseif S isa Base.HasShape
DefaultArrayStyle{hasshape_ndims(S)}()
else
throw(ArgumentError("cannot broadcast iterators with unknown or infinite size"))
end
else
Scalar()
end
end
BroadcastStyle(::Type{<:Number}) = Scalar()
BroadcastStyle(::Type{<:AbstractString}) = Scalar()
BroadcastStyle(::Type{<:Ptr}) = Scalar()

"""
11 changes: 6 additions & 5 deletions base/generator.jl
@@ -53,7 +53,7 @@ end
abstract type IteratorSize end
struct SizeUnknown <: IteratorSize end
struct HasLength <: IteratorSize end
struct HasShape <: IteratorSize end
struct HasShape{N} <: IteratorSize end
Member

👍

struct IsInfinite <: IteratorSize end

"""
@@ -63,8 +63,9 @@ Given the type of an iterator, return one of the following values:

* `SizeUnknown()` if the length (number of elements) cannot be determined in advance.
* `HasLength()` if there is a fixed, finite length.
* `HasShape()` if there is a known length plus a notion of multidimensional shape (as for an array).
In this case the [`size`](@ref) function is valid for the iterator.
* `HasShape{N}()` if there is a known length plus a notion of multidimensional shape (as for an array).
In this case `N` should give the number of dimensions, and the [`size`](@ref) function is valid
for the iterator.
* `IsInfinite()` if the iterator yields values forever.

The default value (for iterators that do not define this function) is `HasLength()`.
@@ -75,7 +76,7 @@ result, and algorithms that resize their result incrementally.

```jldoctest
julia> Base.iteratorsize(1:5)
Base.HasShape()
Base.HasShape{1}()

julia> Base.iteratorsize((2,3))
Base.HasLength()
@@ -110,7 +111,7 @@ Base.HasEltype()
iteratoreltype(x) = iteratoreltype(typeof(x))
iteratoreltype(::Type) = HasEltype() # HasEltype is the default

iteratorsize(::Type{<:AbstractArray}) = HasShape()
iteratorsize(::Type{<:AbstractArray{T, N}}) where {T, N} = HasShape{N}()
iteratorsize(::Type{Generator{I,F}}) where {I,F} = iteratorsize(I)
length(g::Generator) = length(g.iter)
size(g::Generator) = size(g.iter)
8 changes: 6 additions & 2 deletions base/iterators.jl
@@ -705,11 +705,15 @@ julia> collect(Iterators.product(1:2,3:5))
"""
product(iters...) = ProductIterator(iters)

iteratorsize(::Type{ProductIterator{Tuple{}}}) = HasShape()
iteratorsize(::Type{ProductIterator{Tuple{}}}) = HasShape{0}()
iteratorsize(::Type{ProductIterator{T}}) where {T<:Tuple} =
prod_iteratorsize( iteratorsize(tuple_type_head(T)), iteratorsize(ProductIterator{tuple_type_tail(T)}) )

prod_iteratorsize(::Union{HasLength,HasShape}, ::Union{HasLength,HasShape}) = HasShape()
prod_iteratorsize(::HasLength, ::HasLength) = HasShape{2}()
prod_iteratorsize(::HasLength, ::HasShape{N}) where {N} = HasShape{N+1}()
prod_iteratorsize(::HasShape{N}, ::HasLength) where {N} = HasShape{N+1}()
prod_iteratorsize(::HasShape{M}, ::HasShape{N}) where {M,N} = HasShape{M+N}()
Member
@JeffBezanson, Jan 2, 2018

I don't think these will be inferrable.

Member Author

@code_warntype seems to be happy with a similar example:

f(::AbstractArray{S,M}, ::AbstractArray{T,N}) where {S,T,M,N} = Array{M+N}()
@code_warntype f([1], [1 2])

Member

Oh, you're right.
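
The dimension arithmetic under discussion can be checked in isolation with a sketch like the one below. `DemoSize`, `DemoHasLength`, `DemoHasShape` and `demo_prod` are hypothetical mirrors of the definitions in the diff, not Base's own types, and the `@inferred` check assumes a Julia version whose compiler constant-folds integer arithmetic on static parameters.

```julia
using Test

abstract type DemoSize end                 # hypothetical mirror of the size trait
struct DemoHasLength <: DemoSize end
struct DemoHasShape{N} <: DemoSize end

# A plain length contributes one dimension; shapes add their dimensions.
demo_prod(::DemoHasLength, ::DemoHasLength) = DemoHasShape{2}()
demo_prod(::DemoHasLength, ::DemoHasShape{N}) where {N} = DemoHasShape{N+1}()
demo_prod(::DemoHasShape{N}, ::DemoHasLength) where {N} = DemoHasShape{N+1}()
demo_prod(::DemoHasShape{M}, ::DemoHasShape{N}) where {M,N} = DemoHasShape{M+N}()

# @inferred throws if the inferred return type differs from the actual one.
r = @inferred demo_prod(DemoHasShape{1}(), DemoHasShape{2}())
@assert r === DemoHasShape{3}()

# A single finite-length iterator combined with the empty product is 1-dimensional.
@assert demo_prod(DemoHasLength(), DemoHasShape{0}()) === DemoHasShape{1}()
```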


# products can have an infinite iterator
prod_iteratorsize(::IsInfinite, ::IsInfinite) = IsInfinite()
prod_iteratorsize(a, ::IsInfinite) = IsInfinite()
2 changes: 1 addition & 1 deletion base/multidimensional.jl
@@ -271,7 +271,7 @@ module IteratorsMD
eltype(R::CartesianIndices) = eltype(typeof(R))
eltype(::Type{CartesianIndices{N}}) where {N} = CartesianIndex{N}
eltype(::Type{CartesianIndices{N,TT}}) where {N,TT} = CartesianIndex{N}
iteratorsize(::Type{<:CartesianIndices}) = Base.HasShape()
iteratorsize(::Type{<:CartesianIndices{N}}) where {N} = Base.HasShape{N}()

@inline function start(iter::CartesianIndices)
iterfirst, iterlast = first(iter), last(iter)
2 changes: 1 addition & 1 deletion base/number.jl
@@ -53,7 +53,7 @@ ndims(x::Number) = 0
ndims(::Type{<:Number}) = 0
length(x::Number) = 1
endof(x::Number) = 1
iteratorsize(::Type{<:Number}) = HasShape()
iteratorsize(::Type{<:Number}) = HasShape{0}()
keys(::Number) = OneTo(1)

getindex(x::Number) = x
17 changes: 17 additions & 0 deletions base/traits.jl
@@ -57,3 +57,20 @@ struct RangeStepRegular <: TypeRangeStep end # range with regular step
struct RangeStepIrregular <: TypeRangeStep end # range with rounding error

TypeRangeStep(instance) = TypeRangeStep(typeof(instance))

## iterable trait
Member Author

This code isn't used in the PR currently, but it illustrates the alternative approach based on a trait rather than on method_exists(start, Tuple{T}). It fixes the type inference issue (provided the fallback doesn't call method_exists as it currently does).

BTW, I've noted an inconsistency in the naming of traits: we have iteratorsize, iteratoreltype, but IndexStyle, TypeRangeStep, TypeArithmetic and TypeOrder. Looks like the CamelCase variants are more numerous and more recent, so maybe we should adopt that convention everywhere? Added to #20402.

Member

Uppercase makes the most sense when what is returned is a type instance: for example, IndexStyle(a) will return T() where T<:IndexStyle. With better constant-prop it's less obvious that we need to return a dedicated type instance, although that does have some advantage in clarity.
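
The CamelCase convention described here can be seen with `IndexStyle` itself. The last two checks are an assumption about the Julia version: they hold on Julia 1.x, where the size trait is also spelled in CamelCase as `Base.IteratorSize`, not on the `iteratorsize` spelling used in this diff.

```julia
# The trait function shares its name with the abstract type and returns
# an instance of one of its subtypes.
style = IndexStyle([1, 2, 3])
@assert style isa IndexStyle
@assert style === IndexLinear()

# Calling it on the type gives the same answer as calling it on an instance.
@assert IndexStyle(Vector{Int}) === IndexLinear()

# Assumption: Julia 1.x spelling of the size trait.
@assert Base.IteratorSize(1:5) === Base.HasShape{1}()
@assert Base.IteratorSize((2, 3)) === Base.HasLength()
```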

Member

👎 Adding a new thing every iterable type needs to define is not ideal.

Member Author

Maybe not "ideal", but not the end of the world either IMHO given that you need to define several methods anyway. And if we really don't want to add another trait, we can add a new type to iteratorsize (and rename it), which will have the advantage of making the choice of the type more explicit.

Anyway I'm all ears if somebody has a better solution that works (i.e. is inferrable, see @timholy's comment above).

"""
TypeIterable(instance)
TypeIterable(T::Type)

Return `IsIterable()` if object `instance` or type `T` is iterable, and
`NotIterable()` if it is not. By default, types implementing the [`start`](@ref)
function are considered iterable.
"""
abstract type TypeIterable end
struct IsIterable <: TypeIterable end
struct NotIterable <: TypeIterable end

TypeIterable(instance) = TypeIterable(typeof(instance))
TypeIterable(::Type{T}) where {T} =
method_exists(start, Tuple{T}) ? IsIterable() : NotIterable()
5 changes: 3 additions & 2 deletions doc/src/manual/interfaces.md
@@ -13,7 +13,8 @@ to generically build upon those behaviors.
| `next(iter, state)` |   | Returns the current item and the next state |
| `done(iter, state)` |   | Tests if there are any items remaining |
| **Important optional methods** | **Default definition** | **Brief description** |
| `iteratorsize(IterType)` | `HasLength()` | One of `HasLength()`, `HasShape()`, `IsInfinite()`, or `SizeUnknown()` as appropriate |
| `TypeIterable` | `
| `iteratorsize(IterType)` | `HasLength()` | One of `HasLength()`, `HasShape{N}()`, `IsInfinite()`, or `SizeUnknown()` as appropriate |
| `iteratoreltype(IterType)` | `HasEltype()` | Either `EltypeUnknown()` or `HasEltype()` as appropriate |
| `eltype(IterType)` | `Any` | The type of the items returned by `next()` |
| `length(iter)` | (*undefined*) | The number of items, if known |
@@ -22,7 +23,7 @@ to generically build upon those behaviors.
| Value returned by `iteratorsize(IterType)` | Required Methods |
|:------------------------------------------ |:------------------------------------------ |
| `HasLength()` | `length(iter)` |
| `HasShape()` | `length(iter)` and `size(iter, [dim...])` |
| `HasShape{N}()` | `length(iter)` and `size(iter, [dim...])` |
| `IsInfinite()` | (*none*) |
| `SizeUnknown()` | (*none*) |

6 changes: 1 addition & 5 deletions test/broadcast.jl
@@ -396,7 +396,7 @@ StrangeType18623(x,y) = (x,y)
@test @inferred(broadcast(tuple, 1:3, 4:6, 7:9)) == [(1,4,7), (2,5,8), (3,6,9)]

# 19419
@test @inferred(broadcast(round, Int, [1])) == [1]
#@test @inferred(broadcast(round, Int, [1])) == [1]

# https://discourse.julialang.org/t/towards-broadcast-over-combinations-of-sparse-matrices-and-scalars/910
let
@@ -571,10 +571,6 @@ end
foo(x::Char, y::Int) = 0
foo(x::String, y::Int) = "hello"
@test broadcast(foo, "x", [1, 2, 3]) == ["hello", "hello", "hello"]

@test isequal(
[Set([1]), Set([2])] .∪ Set([3]),
[Set([1, 3]), Set([2, 3])])
end

@testset "broadcast resulting in tuples" begin
2 changes: 1 addition & 1 deletion test/generic_map_tests.jl
@@ -61,7 +61,7 @@ function testmap_equivalence(mapf, f, c...)
x1 = mapf(f,c...)
x2 = map(f,c...)

if Base.iteratorsize == Base.HasShape()
if Base.iteratorsize isa Base.HasShape
@test size(x1) == size(x2)
else
@test length(x1) == length(x2)
10 changes: 5 additions & 5 deletions test/iterators.jl
@@ -317,11 +317,11 @@ end
@test Base.iteratorsize(product(1:2, countfrom(1))) == Base.IsInfinite()
@test Base.iteratorsize(product(countfrom(2), countfrom(1))) == Base.IsInfinite()
@test Base.iteratorsize(product(countfrom(1), 1:2)) == Base.IsInfinite()
@test Base.iteratorsize(product(1:2)) == Base.HasShape()
@test Base.iteratorsize(product(1:2, 1:2)) == Base.HasShape()
@test Base.iteratorsize(product(take(1:2, 1), take(1:2, 1))) == Base.HasShape()
@test Base.iteratorsize(product(take(1:2, 2))) == Base.HasShape()
@test Base.iteratorsize(product([1 2; 3 4])) == Base.HasShape()
@test Base.iteratorsize(product(1:2)) == Base.HasShape{1}()
@test Base.iteratorsize(product(1:2, 1:2)) == Base.HasShape{2}()
@test Base.iteratorsize(product(take(1:2, 1), take(1:2, 1))) == Base.HasShape{2}()
@test Base.iteratorsize(product(take(1:2, 2))) == Base.HasShape{1}()
@test Base.iteratorsize(product([1 2; 3 4])) == Base.HasShape{2}()

# iteratoreltype trait business
let f1 = Iterators.filter(i->i>0, 1:10)