From 99b3c8e79d78ee10feeb5b95ee9b70b67be971f4 Mon Sep 17 00:00:00 2001 From: Jeff Bezanson Date: Mon, 23 Dec 2013 14:43:21 -0500 Subject: [PATCH] implement new array assignment shape matching rule. fixes #4048, fixes #4383 this rule ignores singleton dimensions, and allows the last dimension of one side to match all trailing dimensions of the other. --- base/array.jl | 15 +-------- base/operators.jl | 82 +++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 73 insertions(+), 24 deletions(-) diff --git a/base/array.jl b/base/array.jl index 75824fe94f994..41f57a026220f 100644 --- a/base/array.jl +++ b/base/array.jl @@ -593,20 +593,7 @@ function setindex!(A::Array, x, I::Union(Real,AbstractArray)...) assign_cache = Dict() end X = x - nel = 1 - for idx in I - nel *= length(idx) - end - if length(X) != nel - throw(DimensionMismatch("")) - end - if ndims(X) > 1 - for i = 1:length(I) - if size(X,i) != length(I[i]) - throw(DimensionMismatch("")) - end - end - end + setindex_shape_check(X, I...) gen_array_index_map(assign_cache, storeind -> quote A[$storeind] = X[refind] refind += 1 diff --git a/base/operators.jl b/base/operators.jl index cf87ef12781d4..5f6c31bd10746 100644 --- a/base/operators.jl +++ b/base/operators.jl @@ -206,21 +206,83 @@ index_shape(I::Real...) = () index_shape(i, I...) = tuple(length(i), index_shape(I...)...) # check for valid sizes in A[I...] = X where X <: AbstractArray +# we want to allow dimensions that are equal up to permutation, but only +# for permutations that leave array elements in the same linear order. +# those are the permutations that preserve the order of the non-singleton +# dimensions. function setindex_shape_check(X::AbstractArray, I...) + li = length(I) + ii = 1 nel = 1 - for idx in I - nel *= length(idx) - end - if length(X) != nel - error("dimensions must match") - end - if ndims(X) > 1 - for i = 1:length(I) - if size(X,i) != length(I[i]) - error("dimensions must match") + xi = 1 + ndx = ndims(X) + match = true + while ii < li + lii = length(I[ii])::Int + ii += 1 + if lii != 1 + nel *= lii + local lxi + while true + lxi = size(X,xi) + xi += 1 + if lxi != 1 || xi > ndx + break + end + end + if xi > ndx + trailing = lii + while ii <= li + lii = length(I[ii])::Int + trailing *= lii + ii += 1 + end + # X's last dimension can match all the trailing indexes + if lxi == trailing && match + return + else + throw(DimensionMismatch("")) + end + else + if lxi != lii + match = false + end end end end + + # last index can match X's trailing dimensions + lii = length(I[ii])::Int + nel *= lii + if lii != trailingsize(X,xi) + match = false + end + + if !(match && length(X)==nel) + throw(DimensionMismatch("")) + end +end + +setindex_shape_check(X::AbstractArray) = (length(X)==1 || throw(DimensionMismatch(""))) + +setindex_shape_check(X::AbstractArray, i) = + (length(X)==length(i) || throw(DimensionMismatch(""))) + +setindex_shape_check{T}(X::AbstractArray{T,1}, i) = + (length(X)==length(i) || throw(DimensionMismatch(""))) + +setindex_shape_check{T}(X::AbstractArray{T,1}, i, j) = + (length(X)==length(i)*length(j) || throw(DimensionMismatch(""))) + +function setindex_shape_check{T}(X::AbstractArray{T,2}, i, j) + li, lj = length(i), length(j) + if length(X) != li*lj + throw(DimensionMismatch("")) + end + sx1 = size(X,1) + if !(li == 1 || li == sx1 || sx1 == 1) + throw(DimensionMismatch("")) + end end # convert to integer index