From 4ae3897105ae1250849dac66b05f8da5b0f444fb Mon Sep 17 00:00:00 2001 From: Sacha Verweij Date: Sat, 19 Nov 2016 12:20:33 -0800 Subject: [PATCH] Clean up broadcast[!] over pairs of sparse matrices, and fix result where the broadcast function does not return zero when both arguments are zero. --- base/sparse/sparsematrix.jl | 427 +++++++++++++++++++++++------------- test/sparse/sparse.jl | 25 +++ 2 files changed, 300 insertions(+), 152 deletions(-) diff --git a/base/sparse/sparsematrix.jl b/base/sparse/sparsematrix.jl index 10b237db9769f..bae755a032bf3 100644 --- a/base/sparse/sparsematrix.jl +++ b/base/sparse/sparsematrix.jl @@ -1474,148 +1474,280 @@ round{To}(::Type{To}, A::SparseMatrixCSC) = round.(To, A) # TODO: More appropriate location? conj!(A::SparseMatrixCSC) = (broadcast!(conj, A.nzval, A.nzval); A) - -## Broadcasting kernels specialized for returning a SparseMatrixCSC - -# Operations with zero result if both operands are zero -function gen_broadcast_body_sparse(f::Function, is_first_sparse::Bool) - F = Expr(:quote, f) - quote - Base.Broadcast.check_broadcast_indices(indices(B), A_1) - Base.Broadcast.check_broadcast_indices(indices(B), A_2) - - colptrB = B.colptr; rowvalB = B.rowval; nzvalB = B.nzval - colptr1 = A_1.colptr; rowval1 = A_1.rowval; nzval1 = A_1.nzval - colptr2 = A_2.colptr; rowval2 = A_2.rowval; nzval2 = A_2.nzval - - nnzB = isempty(B) ? 0 : (nnz(A_1) * div(B.n, A_1.n) * div(B.m, A_1.m) + - nnz(A_2) * div(B.n, A_2.n) * div(B.m, A_2.m)) - if length(rowvalB) < nnzB - resize!(rowvalB, nnzB) - end - if length(nzvalB) < nnzB - resize!(nzvalB, nnzB) - end - z = zero(Tv) - - ptrB = 1 - colptrB[1] = 1 - - Tr1 = eltype(rowval1) - Tr2 = eltype(rowval2) - - @inbounds for col = 1:B.n - ptr1::Int = A_1.n == 1 ? colptr1[1] : colptr1[col] - stop1::Int = A_1.n == 1 ? colptr1[2] : colptr1[col+1] - ptr2::Int = A_2.n == 1 ? colptr2[1] : colptr2[col] - stop2::Int = A_2.n == 1 ? colptr2[2] : colptr2[col+1] - - if A_1.m == A_2.m || (A_1.m == 1 && ptr1 == stop1) || (A_2.m == 1 && ptr2 == stop2) - while ptr1 < stop1 && ptr2 < stop2 - row1 = rowval1[ptr1] - row2 = rowval2[ptr2] - if row1 < row2 - res = ($F)(nzval1[ptr1], z) - if res != z - rowvalB[ptrB] = row1 - nzvalB[ptrB] = res - ptrB += 1 - end - ptr1 += 1 - elseif row2 < row1 - res = ($F)(z, nzval2[ptr2]) - if res != z - rowvalB[ptrB] = row2 - nzvalB[ptrB] = res - ptrB += 1 - end - ptr2 += 1 - else - res = ($F)(nzval1[ptr1], nzval2[ptr2]) - if res != z - rowvalB[ptrB] = row1 - nzvalB[ptrB] = res - ptrB += 1 - end - ptr1 += 1 - ptr2 += 1 - end +## Broadcast operations involving two sparse matrices +function broadcast{Tf}(f::Tf, A::SparseMatrixCSC, B::SparseMatrixCSC) + indextypeC = promote_type(eltype(A.rowval), eltype(B.rowval)) + entrytypeC = Base.promote_eltype_op(f, A, B) + shapeC = Base.to_shape(Base.Broadcast.broadcast_indices(A, B)) + C = spzeros(entrytypeC, indextypeC, shapeC) + return _broadcast_nodimscheck!(f, C, A, B) +end +function broadcast!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC, B::SparseMatrixCSC) + Base.Broadcast.check_broadcast_indices(indices(C), A) + Base.Broadcast.check_broadcast_indices(indices(C), B) + return _broadcast_nodimscheck!(f, C, A, B) +end +function _broadcast_nodimscheck!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC, B::SparseMatrixCSC) + # Check whether f(0...) yields zero, and branch appropriately + fofzeros = f(zero(eltype(A)), zero(eltype(B))) + fpreszeros = fofzeros == zero(fofzeros) + return fpreszeros ? _broadcast_binzeropres!(f, C, A, B) : + _broadcast_notbinzeropres!(f, fofzeros, C, A, B) +end +# TODO: _broadcast_binzeropres! and _broadcast_notbinzeropres! could be more efficient and +# clearer, possibly at the cost of additional code. Consider another rewrite. +function _broadcast_binzeropres!{Tf}(f::Tf, C::SparseMatrixCSC, A::SparseMatrixCSC, B::SparseMatrixCSC) + # Calculate upper bound on number of entries in C + isempty(C) && return C + maxnnzfromA = nnz(A) * div(C.n, A.n) * div(C.m, A.m) + maxnnzfromB = nnz(B) * div(C.n, B.n) * div(C.m, B.m) + maxnnzfromAB = maxnnzfromA + maxnnzfromB + nnzC = maxnnzfromAB + # Resize C to accomodate max number of entries in C + length(C.rowval) < nnzC && resize!(C.rowval, nnzC) + length(C.nzval) < nnzC && resize!(C.nzval, nnzC) + + # Populate C column by column + Cptr = 1 + C.colptr[1] = 1 + @inbounds for j in 1:C.n + # Determine rowval/nzval ranges in A and B corresponding to C's jth column + # If A(/B) has only one column, then A(/B)'s single column corresponds to C's jth column + # If A(/B) has more than one column, then A(/B)'s jth column corresponds to C's jth column + Aptr, Astop = A.n == 1 ? (A.colptr[1], A.colptr[2]) : (A.colptr[j], A.colptr[j + 1]) + Bptr, Bstop = B.n == 1 ? (B.colptr[1], B.colptr[2]) : (B.colptr[j], B.colptr[j + 1]) + + # The following conditional chain separates column-pairs into three cases: (1) the + # columns have the same number of rows, or either or both columns have only one row + # and contain no stored entries; (2) A has more than one row, B has only one row, + # and B's column has a stored entry in that one row; (3) B has more than oen row, + # A has only one row, and A's column has a stored entry in that one row. + + if A.m == B.m || (A.m == 1 && Aptr == Astop) || (B.m == 1 && Bptr == Bstop) + # Case (1): the columns have the same number of rows, or either or both + # columns have only one row and contain no stored entries. + # + # If both columns contain stored entries, then sweep (in step) through those + # stored entries till exhaustion of either of the columns' stored entries. + # For each stored entry / entry-pair encountered, compute and store the + # appropriate row-value pair in C's jth column. + while Aptr < Astop && Bptr < Bstop + rowA = A.rowval[Aptr] + rowB = B.rowval[Bptr] + if rowA < rowB # an entry stored in A but not in B + valC = f(A.nzval[Aptr], zero(eltype(C))) + rowC::eltype(C.rowval) = rowA + Aptr += one(Aptr) + elseif rowB < rowA # an entry stored in B but not in A + valC = f(zero(eltype(C)), B.nzval[Bptr]) + rowC = rowB + Bptr += one(Bptr) + else # rowA == rowB, an entry stored in both A and B + valC = f(A.nzval[Aptr], B.nzval[Bptr]) + rowC = rowA + Aptr += one(Aptr) + Bptr += one(Bptr) end - - while ptr1 < stop1 - res = ($F)(nzval1[ptr1], z) - if res != z - row1 = rowval1[ptr1] - rowvalB[ptrB] = row1 - nzvalB[ptrB] = res - ptrB += 1 - end - ptr1 += 1 + if valC != zero(eltype(C)) + C.rowval[Cptr] = rowC + C.nzval[Cptr] = valC + Cptr += 1 end - - while ptr2 < stop2 - res = ($F)(z, nzval2[ptr2]) - if res != z - row2 = rowval2[ptr2] - rowvalB[ptrB] = row2 - nzvalB[ptrB] = res - ptrB += 1 - end - ptr2 += 1 + end + # If B's column had no stored entries, or we exhausted the stored entries in + # B's column without exhausting those in A's column, sweep over the (remaining) + # stored entries in A's column. For each such stored entry encountered, compute + # and store the appropriate row-value pair in C's jth column. + while Aptr < Astop + valC = f(A.nzval[Aptr], zero(eltype(B))) + if valC != zero(eltype(C)) + C.rowval[Cptr] = A.rowval[Aptr] + C.nzval[Cptr] = valC + Cptr += 1 end - elseif A_1.m != 1 # A_1.m != 1 && A_2.m == 1 - scalar2 = A_2.nzval[ptr2] - row1 = ptr1 < stop1 ? rowval1[ptr1] : -one(Tr1) - for row2 = one(Tr2):Tr2(B.m) - if ptr1 >= stop1 || row1 != row2 - res = ($F)(z, scalar2) - if res != z - rowvalB[ptrB] = row2 - nzvalB[ptrB] = res - ptrB += 1 - end - else - res = ($F)(nzval1[ptr1], scalar2) - if res != z - rowvalB[ptrB] = row1 - nzvalB[ptrB] = res - ptrB += 1 - end - ptr1 += 1 - row1 = ptr1 < stop1 ? rowval1[ptr1] : -one(Tr1) - end + Aptr += one(Aptr) + end + # If A's column had no stored entries, or we exhausted the stored entries in + # A's column without exhausting those in B's column, sweep over the (remaining) + # stored entries in A's column. For each such stored entry encountered, compute + # and store the appropriate row-value pair in C's jth column. + while Bptr < Bstop + valC = f(zero(eltype(A)), B.nzval[Bptr]) + if valC != zero(eltype(C)) + C.rowval[Cptr] = B.rowval[Bptr] + C.nzval[Cptr] = valC + Cptr += 1 end - else # A_1.m == 1 && A_2.m != 1 - scalar1 = nzval1[ptr1] - row2 = ptr2 < stop2 ? rowval2[ptr2] : -one(Tr2) - for row1 = one(Tr1):Tr1(B.m) - if ptr2 >= stop2 || row1 != row2 - res = ($F)(scalar1, z) - if res != z - rowvalB[ptrB] = row1 - nzvalB[ptrB] = res - ptrB += 1 - end - else - res = ($F)(scalar1, nzval2[ptr2]) - if res != z - rowvalB[ptrB] = row2 - nzvalB[ptrB] = res - ptrB += 1 - end - ptr2 += 1 - row2 = ptr2 < stop2 ? rowval2[ptr2] : -one(Tr2) - end + Bptr += one(Bptr) + end + elseif A.m != 1 + # Case (2): A has more than one row, B has only a single row, and B's column + # has a stored entry in that row (A.m != 1 && B.m == 1 && (Bptr == Bstop - 1)). + # TODO: This could be substantially more efficient. + valB = B.nzval[Bptr] + rowA = Aptr < Astop ? A.rowval[Aptr] : -one(eltype(A.rowval)) + for rowB::eltype(B.rowval) in 1:C.m + if Aptr >= Astop || rowA != rowB + valC = f(zero(eltype(A)), valB) + rowC::eltype(C.rowval) = rowB + else + valC = f(A.nzval[Aptr], valB) + rowC = rowA + Aptr += one(Aptr) + rowA = Aptr < Astop ? A.rowval[Aptr] : -one(eltype(A.rowval)) + end + if valC != zero(eltype(C)) + C.rowval[Cptr] = rowC + C.nzval[Cptr] = valC + Cptr += 1 + end + end + else + # Case (3): B has more than one row, A has only a single row, and A's column + # has a stored entry in that row (B.m != 1 && A.m == 1 && (Aptr == Astop - 1)). + # TODO: This could be substantially more efficient. + valA = A.nzval[Aptr] + rowB = Bptr < Bstop ? B.rowval[Bptr] : -one(eltype(B.rowval)) + for rowA::eltype(A.rowval) in 1:C.m + if Bptr >= Bstop || rowA != rowB + valC = f(valA, zero(eltype(B))) + rowC::eltype(C.rowval) = rowA + else + valC = f(valA, B.nzval[Bptr]) + rowC = rowB + Bptr += one(Bptr) + rowB = Bptr < Bstop ? B.rowval[Bptr] : -one(eltype(B.rowval)) + end + if valC != zero(eltype(C)) + C.rowval[Cptr] = rowC + C.nzval[Cptr] = valC + Cptr += 1 end end - colptrB[col+1] = ptrB end - deleteat!(rowvalB, colptrB[end]:length(rowvalB)) - deleteat!(nzvalB, colptrB[end]:length(nzvalB)) - nothing + # Tie off the column pointers for C's jth column and proceed to the next column + C.colptr[j + 1] = Cptr + end + resize!(C.rowval, Cptr - 1) + resize!(C.nzval, Cptr - 1) + return C +end +function _broadcast_notbinzeropres!{Tf}(f::Tf, fillvalue, C::SparseMatrixCSC, A::SparseMatrixCSC, B::SparseMatrixCSC) + # Allocate storage for dense C + nnzC = C.m * C.n + resize!(C.rowval, nnzC) + resize!(C.nzval, nnzC) + # Build structure for dense C + copy!(C.colptr, 1:C.m:(nnzC + 1)) + for k in 1:C.m:(nnzC - C.m + 1) + copy!(C.rowval, k, 1:C.m) + end + + # Populate C with fillvalue, then column by column from A and B + fill!(C.nzval, fillvalue) + @inbounds for (j, jo) in zip(1:C.n, 0:C.m:(nnzC - 1)) + # Determine rowval/nzval ranges in A and B corresponding to C's jth column + # If A(/B) has only one column, then A(/B)'s single column corresponds to C's jth column + # If A(/B) has more than one column, then A(/B)'s jth column corresponds to C's jth column + Aptr, Astop = A.n == 1 ? (A.colptr[1], A.colptr[2]) : (A.colptr[j], A.colptr[j + 1]) + Bptr, Bstop = B.n == 1 ? (B.colptr[1], B.colptr[2]) : (B.colptr[j], B.colptr[j + 1]) + + # The following conditional chain separates column-pairs into three cases: (1) the + # columns have the same number of rows, or either or both columns have only one row + # and contain no stored entries; (2) A has more than one row, B has only one row, + # and B's column has a stored entry in that one row; (3) B has more than oen row, + # A has only one row, and A's column has a stored entry in that one row. + + if A.m == B.m || (A.m == 1 && Aptr == Astop) || (B.m == 1 && Bptr == Bstop) + # Case (1): the columns have the same number of rows, or either or both + # columns have only one row and contain no stored entries. + # + # If both columns contain stored entries, then sweep (in step) through those + # stored entries till exhaustion of either of the columns' stored entries. + # For each stored entry / entry-pair encountered, compute and store the + # appropriate value in C's jth column. + while Aptr < Astop && Bptr < Bstop + rowA = A.rowval[Aptr] + rowB = B.rowval[Bptr] + if rowA < rowB # an entry stored in A but not in B + valC = f(A.nzval[Aptr], zero(eltype(C))) + rowC::eltype(C.rowval) = rowA + Aptr += one(Aptr) + elseif rowB < rowA # an entry stored in B but not in A + valC = f(zero(eltype(C)), B.nzval[Bptr]) + rowC = rowB + Bptr += one(Bptr) + else # rowA == rowB, an entry stored in both A and B + valC = f(A.nzval[Aptr], B.nzval[Bptr]) + rowC = rowA + Aptr += one(Aptr) + Bptr += one(Aptr) + end + valC != fillvalue && (C.nzval[jo + rowC] = valC) + end + # If B's column had no stored entries, or we exhausted the stored entries in + # B's column without exhausting those in A's column, sweep over the (remaining) + # stored entries in A's column. For each such stored entry encountered, compute + # and store the appropriate value in C's jth column. + while Aptr < Astop + valC = f(A.nzval[Aptr], zero(eltype(C))) + valC != fillvalue && (C.nzval[jo + A.rowval[Aptr]] = valC) + Aptr += one(Aptr) + end + # If A's column had no stored entries, or we exhausted the stored entries in + # A's column without exhausting those in B's column, sweep over the (remaining) + # stored entries in A's column. For each such stored entry encountered, compute + # and store the appropriate value in C's jth column. + while Bptr < Bstop + valC = f(zero(eltype(C)), B.nzval[Bptr]) + valC != fillvalue && (C.nzval[jo + B.rowval[Bptr]] = valC) + Bptr += one(Bptr) + end + elseif A.m != 1 + # Case (2): A has more than one row, B has only a single row, and B's column + # has a stored entry in that row (A.m != 1 && B.m == 1 && (Bptr == Bstop - 1)). + # TODO: This could be substantially more efficient. + valB = B.nzval[Bptr] + rowA = Aptr < Astop ? A.rowval[Aptr] : -one(eltype(A.rowval)) + for rowB::eltype(B.rowval) in 1:C.m + if Aptr >= Astop || rowA != rowB + valC = f(zero(eltype(C)), valB) + rowC::eltype(C.rowval) = rowB + else + valC = f(A.nzval[Aptr], valB) + rowC = rowA + Aptr += one(Aptr) + rowA = Aptr < Astop ? A.rowval[Aptr] : -one(eltype(A.rowval)) + end + valC != fillvalue && (C.nzval[jo + rowC] = valC) + end + else + # Case (3): B has more than one row, A has only a single row, and A's column + # has a stored entry in that row (B.m != 1 && A.m == 1 && (Aptr == Astop - 1)). + # TODO: This could be substantially more efficient. + valA = A.nzval[Aptr] + rowB = Bptr < Bstop ? B.rowval[Bptr] : -one(eltype(B.rowval)) + for rowA::eltype(A.rowval) in 1:C.m + if Bptr >= Bstop || rowA != rowB + valC = f(valA, zero(eltype(C))) + rowC::eltype(C.rowval) = rowA + else + valC = f(valA, B.nzval[Bptr]) + rowC = rowB + Bptr += one(Bptr) + rowB = Bptr < Bstop ? B.rowval[Bptr] : -one(eltype(B.rowval)) + end + valC != fillvalue && (C.nzval[jo + rowC] = valC) + end + end end + return C end + +## Define unexported broadcast_zpreserving[!] methods +# TODO: Sort out what to do with broadcast_zpreserving and dependencies + function gen_broadcast_function_sparse(genbody::Function, f::Function, is_first_sparse::Bool) body = genbody(f, is_first_sparse) @eval let @@ -1695,7 +1827,6 @@ end for (Bsig, A1sig, A2sig, gbb, funcname) in ( - (SparseMatrixCSC , SparseMatrixCSC , SparseMatrixCSC, :gen_broadcast_body_sparse, :broadcast!), (SparseMatrixCSC , SparseMatrixCSC , Array, :gen_broadcast_body_zpreserving, :broadcast_zpreserving!), (SparseMatrixCSC , Array , SparseMatrixCSC, :gen_broadcast_body_zpreserving, :broadcast_zpreserving!), (SparseMatrixCSC , Number , SparseMatrixCSC, :gen_broadcast_body_zpreserving, :broadcast_zpreserving!), @@ -1713,10 +1844,6 @@ for (Bsig, A1sig, A2sig, gbb, funcname) in end # let broadcast_cache end - -broadcast{Tv1,Ti1,Tv2,Ti2}(f::Function, A_1::SparseMatrixCSC{Tv1,Ti1}, A_2::SparseMatrixCSC{Tv2,Ti2}) = - broadcast!(f, spzeros(promote_type(Tv1, Tv2), promote_type(Ti1, Ti2), to_shape(broadcast_indices(A_1, A_2))), A_1, A_2) - @inline broadcast_zpreserving!(args...) = broadcast!(args...) @inline broadcast_zpreserving(args...) = Base.Broadcast.broadcast_elwise_op(args...) broadcast_zpreserving{Tv1,Ti1,Tv2,Ti2}(f::Function, A_1::SparseMatrixCSC{Tv1,Ti1}, A_2::SparseMatrixCSC{Tv2,Ti2}) = @@ -1729,21 +1856,17 @@ broadcast_zpreserving{Tv,Ti}(f::Function, A_1::Union{Array,BitArray,Number}, A_2 ## Binary arithmetic and boolean operators -for op in (+, -, min, max, &, |, xor) - body = gen_broadcast_body_sparse(op, true) - OP = Symbol(string(op)) - @eval begin - function ($OP){Tv1,Ti1,Tv2,Ti2}(A_1::SparseMatrixCSC{Tv1,Ti1}, A_2::SparseMatrixCSC{Tv2,Ti2}) - if size(A_1,1) != size(A_2,1) || size(A_1,2) != size(A_2,2) - throw(DimensionMismatch("")) - end - Tv = promote_op($op, Tv1, Tv2) - B = spzeros(Tv, promote_type(Ti1, Ti2), to_shape(broadcast_indices(A_1, A_2))) - $body - B - end - end -end # macro +# TODO: These seven functions should probably be reimplemented in terms of sparse map +# when a better sparse map exists. (And vectorized min, max, &, |, and xor should be +# deprecated in favor of compact-broadcast syntax.) +_checksameshape(A, B) = size(A) == size(B) || throw(DimensionMismatch("size(A) must match size(B)")) +(+)(A::SparseMatrixCSC, B::SparseMatrixCSC) = (_checksameshape(A, B); broadcast(+, A, B)) +(-)(A::SparseMatrixCSC, B::SparseMatrixCSC) = (_checksameshape(A, B); broadcast(-, A, B)) +min(A::SparseMatrixCSC, B::SparseMatrixCSC) = (_checksameshape(A, B); broadcast(min, A, B)) +max(A::SparseMatrixCSC, B::SparseMatrixCSC) = (_checksameshape(A, B); broadcast(max, A, B)) +(&)(A::SparseMatrixCSC, B::SparseMatrixCSC) = (_checksameshape(A, B); broadcast(&, A, B)) +(|)(A::SparseMatrixCSC, B::SparseMatrixCSC) = (_checksameshape(A, B); broadcast(|, A, B)) +xor(A::SparseMatrixCSC, B::SparseMatrixCSC) = (_checksameshape(A, B); broadcast(xor, A, B)) (.+)(A::SparseMatrixCSC, B::Number) = Array(A) .+ B ( +)(A::SparseMatrixCSC, B::Array ) = Array(A) + B diff --git a/test/sparse/sparse.jl b/test/sparse/sparse.jl index 8ef00d0e8be57..46e1985e9ce48 100644 --- a/test/sparse/sparse.jl +++ b/test/sparse/sparse.jl @@ -1662,3 +1662,28 @@ end # 19304 @inferred hcat(sparse(rand(2,1)), eye(2,2)) + +# Test that broadcast[!](f, [C::SparseMatrixCSC], A::SparseMatrixCSC, B::SparseMatrixCSC) +# returns the correct (densely populated) result when f(zero(eltype(A)), zero(eltype(B))) != 0 +let + N = 5 + sparsesqrmat = sprand(N, N, 0.5) + sparsesqrmat2 = sprand(N, N, 0.5) + sparserowmat = sprand(1, N, 0.5) + sparsecolmat = sprand(N, 1, 0.5) + sparse1x1matz = spzeros(1, 1) + sparse1x1mato = spones(sparse1x1matz) + zeroscourge = (x, y) -> x + y + 1 + # test case where the matrices have the same shape and no singleton dimensions + @test broadcast(zeroscourge, sparsesqrmat, sparsesqrmat2) == + broadcast(zeroscourge, Matrix(sparsesqrmat), Matrix(sparsesqrmat2)) + # test combinations where either or both matrices have one or more singleton dimensions + sparsemats = (sparsesqrmat, sparserowmat, sparsecolmat, sparse1x1matz, sparse1x1mato) + densemats = map(Matrix, sparsemats) + for (sparseA, denseA) in zip(sparsemats, densemats) + for (sparseB, denseB) in zip(sparsemats, densemats) + @test broadcast(zeroscourge, sparseA, sparseB) == + broadcast(zeroscourge, denseA, denseB) + end + end +end