From 70fc3cdc11b086fc6c70595006d2a8398d5d7e6b Mon Sep 17 00:00:00 2001 From: Branwen Snelling Date: Fri, 4 Mar 2022 14:28:52 +0000 Subject: [PATCH] define _maxndims methods for small tuples to help inference --- base/broadcast.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/base/broadcast.jl b/base/broadcast.jl index 20873adbf1bd9..1896e5edad105 100644 --- a/base/broadcast.jl +++ b/base/broadcast.jl @@ -263,11 +263,15 @@ end Base.IteratorSize(::Type{T}) where {T<:Broadcasted} = Base.HasShape{ndims(T)}() Base.ndims(BC::Type{<:Broadcasted{<:Any,Nothing}}) = _maxndims(fieldtype(BC, 2)) -function Base.ndims(BC::Type{<:Broadcasted{<:AbstractArrayStyle{N},Nothing}}) where {N} - N isa Integer && return N - _maxndims(fieldtype(BC, 2)) +Base.ndims(::Type{<:Broadcasted{<:AbstractArrayStyle{N},Nothing}}) where {N<:Integer} = N + +_maxndims(T::Type{<:Tuple}) = reduce(max, (ntuple(n -> _ndims(fieldtype(T, n)), Base._counttuple(T)))) +_maxndims(::Type{<:Tuple{T}}) where {T} = ndims(T) +_maxndims(::Type{<:Tuple{T}}) where {T<:Tuple} = _ndims(T) +function _maxndims(::Type{<:Tuple{T, S}}) where {T, S} + return T<:Tuple || S<:Tuple ? max(_ndims(T), _ndims(S)) : max(ndims(T), ndims(S)) end -Base.@pure _maxndims(T) = mapfoldl(_ndims, max, fieldtypes(T)) + _ndims(x) = ndims(x) _ndims(::Type{<:Tuple}) = 1