Skip to content

Commit

Permalink
implement new array assignment shape matching rule. fixes #4048, fixes
Browse files Browse the repository at this point in the history
…#4383

this rule ignores singleton dimensions, and allows the last dimension of
one side to match all trailing dimensions of the other.
  • Loading branch information
JeffBezanson committed Dec 28, 2013
1 parent 851b2c8 commit 99b3c8e
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 24 deletions.
15 changes: 1 addition & 14 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -593,20 +593,7 @@ function setindex!(A::Array, x, I::Union(Real,AbstractArray)...)
assign_cache = Dict()
end
X = x
nel = 1
for idx in I
nel *= length(idx)
end
if length(X) != nel
throw(DimensionMismatch(""))
end
if ndims(X) > 1
for i = 1:length(I)
if size(X,i) != length(I[i])
throw(DimensionMismatch(""))
end
end
end
setindex_shape_check(X, I...)
gen_array_index_map(assign_cache, storeind -> quote
A[$storeind] = X[refind]
refind += 1
Expand Down
82 changes: 72 additions & 10 deletions base/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,21 +206,83 @@ index_shape(I::Real...) = ()
index_shape(i, I...) = tuple(length(i), index_shape(I...)...)

# check for valid sizes in A[I...] = X where X <: AbstractArray
# we want to allow dimensions that are equal up to permutation, but only
# for permutations that leave array elements in the same linear order.
# those are the permutations that preserve the order of the non-singleton
# dimensions.
function setindex_shape_check(X::AbstractArray, I...)
li = length(I)
ii = 1
nel = 1
for idx in I
nel *= length(idx)
end
if length(X) != nel
error("dimensions must match")
end
if ndims(X) > 1
for i = 1:length(I)
if size(X,i) != length(I[i])
error("dimensions must match")
xi = 1
ndx = ndims(X)
match = true
while ii < li
lii = length(I[ii])::Int
ii += 1
if lii != 1
nel *= lii
local lxi
while true
lxi = size(X,xi)
xi += 1
if lxi != 1 || xi > ndx
break
end
end
if xi > ndx
trailing = lii
while ii <= li
lii = length(I[ii])::Int
trailing *= lii
ii += 1
end
# X's last dimension can match all the trailing indexes
if lxi == trailing && match
return
else
throw(DimensionMismatch(""))
end
else
if lxi != lii
match = false
end
end
end
end

# last index can match X's trailing dimensions
lii = length(I[ii])::Int
nel *= lii
if lii != trailingsize(X,xi)
match = false
end

if !(match && length(X)==nel)
throw(DimensionMismatch(""))
end
end

setindex_shape_check(X::AbstractArray) = (length(X)==1 || throw(DimensionMismatch("")))

setindex_shape_check(X::AbstractArray, i) =
(length(X)==length(i) || throw(DimensionMismatch("")))

setindex_shape_check{T}(X::AbstractArray{T,1}, i) =
(length(X)==length(i) || throw(DimensionMismatch("")))

setindex_shape_check{T}(X::AbstractArray{T,1}, i, j) =
(length(X)==length(i)*length(j) || throw(DimensionMismatch("")))

function setindex_shape_check{T}(X::AbstractArray{T,2}, i, j)
li, lj = length(i), length(j)
if length(X) != li*lj
throw(DimensionMismatch(""))
end
sx1 = size(X,1)
if !(li == 1 || li == sx1 || sx1 == 1)
throw(DimensionMismatch(""))
end
end

# convert to integer index
Expand Down

0 comments on commit 99b3c8e

Please sign in to comment.