Skip to content

Commit

Permalink
Add default IsScalar trait
Browse files Browse the repository at this point in the history
  • Loading branch information
pabloferz committed Jun 21, 2016
1 parent 1069aae commit 1ed693d
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 3 deletions.
3 changes: 2 additions & 1 deletion base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ promote_rule{T,n,S}(::Type{Array{T,n}}, ::Type{Array{S,n}}) = Array{promote_type

# make a collection similar to `c` and appropriate for collecting `itr`
_similar_for(c::AbstractArray, T, itr, ::SizeUnknown) = similar(c, T, 0)
_similar_for(c::AbstractArray, T, itr, ::IsScalar) = similar(c, T, ())
_similar_for(c::AbstractArray, T, itr, ::HasLength) = similar(c, T, Int(length(itr)::Integer))
_similar_for(c::AbstractArray, T, itr, ::HasShape) = similar(c, T, convert(Dims,size(itr)))
_similar_for(c, T, itr, isz) = similar(c, T)
Expand All @@ -226,7 +227,7 @@ collect(itr) = _collect(1:1 #= Array =#, itr, iteratoreltype(itr), iteratorsize(

collect_similar(cont, itr) = _collect(cont, itr, iteratoreltype(itr), iteratorsize(itr))

_collect(cont, itr, ::HasEltype, isz::Union{HasLength,HasShape}) =
_collect(cont, itr, ::HasEltype, isz::Union{HasLength,HasShape,IsScalar}) =
copy!(_similar_for(cont, eltype(itr), itr, isz), itr)

function _collect(cont, itr, ::HasEltype, isz::SizeUnknown)
Expand Down
5 changes: 4 additions & 1 deletion base/generator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,19 @@ end

abstract IteratorSize
immutable SizeUnknown <: IteratorSize end
immutable IsScalar <: IteratorSize end
immutable HasLength <: IteratorSize end
immutable HasShape <: IteratorSize end
immutable IsInfinite <: IteratorSize end

iteratorsize(x) = iteratorsize(typeof(x))
iteratorsize(::Type) = HasLength() # HasLength is the default
iteratorsize(::Type) = IsScalar() # IsScalar is the default

and_iteratorsize{T}(isz::T, ::T) = isz
and_iteratorsize(::HasLength, ::HasShape) = HasLength()
and_iteratorsize(::HasShape, ::HasLength) = HasLength()
and_iteratorsize(::IsScalar, ::Union{HasLength,HasShape}) = IsScalar()
and_iteratorsize(::Union{HasLength,HasShape}, ::IsScalar) = IsScalar()
and_iteratorsize(a, b) = SizeUnknown()

abstract IteratorEltype
Expand Down
4 changes: 4 additions & 0 deletions base/iterator.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ iteratoreltype{I}(::Type{Enumerate{I}}) = iteratoreltype(I)
abstract AbstractZipIterator

zip_iteratorsize(a, b) = and_iteratorsize(a,b) # as `and_iteratorsize` but inherit `Union{HasLength,IsInfinite}` of the shorter iterator
zip_iteratorsize(::IsScalar, ::IsInfinite) = IsScalar()
zip_iteratorsize(::HasLength, ::IsInfinite) = HasLength()
zip_iteratorsize(::HasShape, ::IsInfinite) = HasLength()
zip_iteratorsize(a::IsInfinite, b) = zip_iteratorsize(b,a)
Expand Down Expand Up @@ -413,6 +414,9 @@ iteratorsize{I1,I2}(::Type{Prod{I1,I2}}) = prod_iteratorsize(iteratorsize(I1),it
((x[1][1],x[1][2]...), x[2])
end

prod_iteratorsize(::IsScalar, ::IsScalar) = IsScalar()
prod_iteratorsize(::IsScalar, isz::Union{HasLength,HasShape}) = isz
prod_iteratorsize(isz::Union{HasLength,HasShape}, ::IsScalar) = isz
prod_iteratorsize(::Union{HasLength,HasShape}, ::Union{HasLength,HasShape}) = HasShape()
# products can have an infinite iterator
prod_iteratorsize(::IsInfinite, ::IsInfinite) = IsInfinite()
Expand Down
2 changes: 1 addition & 1 deletion base/number.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ ndims(x::Number) = 0
ndims{T<:Number}(::Type{T}) = 0
length(x::Number) = 1
endof(x::Number) = 1
iteratorsize{T<:Number}(::Type{T}) = HasShape()
iteratorsize{T<:Number}(::Type{T}) = IsScalar()

getindex(x::Number) = x
function getindex(x::Number, i::Integer)
Expand Down
2 changes: 2 additions & 0 deletions base/tuple.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ getindex(t::Tuple, b::AbstractArray{Bool}) = getindex(t,find(b))

## iterating ##

iteratorsize(::Tuple) = HasLength()

start(t::Tuple) = 1
done(t::Tuple, i::Int) = (length(t) < i)
next(t::Tuple, i::Int) = (t[i], i+1)
Expand Down

0 comments on commit 1ed693d

Please sign in to comment.