diff --git a/base/sparse/linalg.jl b/base/sparse/linalg.jl index 1345c9fa848e90..64181d92df4dae 100644 --- a/base/sparse/linalg.jl +++ b/base/sparse/linalg.jl @@ -43,54 +43,77 @@ end # In matrix-vector multiplication, the correct orientation of the vector is assumed. -for (f, op, transp) in ((:A_mul_B, :identity, false), - (:Ac_mul_B, :ctranspose, true), - (:At_mul_B, :transpose, true)) - @eval begin - function $(Symbol(f,:!))(α::Number, A::SparseMatrixCSC, B::StridedVecOrMat, β::Number, C::StridedVecOrMat) - if $transp - A.n == size(C, 1) || throw(DimensionMismatch()) - A.m == size(B, 1) || throw(DimensionMismatch()) - else - A.n == size(B, 1) || throw(DimensionMismatch()) - A.m == size(C, 1) || throw(DimensionMismatch()) - end - size(B, 2) == size(C, 2) || throw(DimensionMismatch()) - nzv = A.nzval - rv = A.rowval - if β != 1 - β != 0 ? scale!(C, β) : fill!(C, zero(eltype(C))) - end - for col = 1:A.n - for k = 1:size(C, 2) - if $transp - tmp = zero(eltype(C)) - @inbounds for j = A.colptr[col]:(A.colptr[col + 1] - 1) - tmp += $(op)(nzv[j])*B[rv[j],k] - end - C[col,k] += α*tmp - else - αxj = α*B[col,k] - @inbounds for j = A.colptr[col]:(A.colptr[col + 1] - 1) - C[rv[j], k] += nzv[j]*αxj - end - end - end - end - C - end +A_mul_B(A::SparseMatrixCSC, B::StridedVecOrMat) = A_mul_B!(_argstuple_AqmulB!(A, B)...) +At_mul_B(A::SparseMatrixCSC, B::StridedVecOrMat) = At_mul_B!(_argstuple_AqmulB!(A, B)...) +Ac_mul_B(A::SparseMatrixCSC, B::StridedVecOrMat) = Ac_mul_B!(_argstuple_AqmulB!(A, B)...) +_argstuple_AqmulB!{TvA,TB}(A::SparseMatrixCSC{TvA}, b::StridedVector{TB}) = + (R = promote_type(TvA, TB); (one(R), A, b, zero(R), similar(b, R, A.n))) +_argstuple_AqmulB!{TvA,TB}(A::SparseMatrixCSC{TvA}, B::StridedMatrix{TB}) = + (R = promote_type(TvA, TB); (one(R), A, B, zero(R), similar(B, R, (A.n, size(B,2))))) + +A_mul_B!(α::Number, A::SparseMatrixCSC, B::StridedVecOrMat, β::Number, C::StridedVecOrMat) = + _Aq_mul_B!(α, A, identity, B, β, C) +At_mul_B!(α::Number, A::SparseMatrixCSC, B::StridedVecOrMat, β::Number, C::StridedVecOrMat) = + _Aq_mul_B!(α, A, transpose, B, β, C) +Ac_mul_B!(α::Number, A::SparseMatrixCSC, B::StridedVecOrMat, β::Number, C::StridedVecOrMat) = + _Aq_mul_B!(α, A, ctranspose, B, β, C) + +function _Aq_mul_B!(α::Number, A::SparseMatrixCSC, transopA::Function, + B::StridedVecOrMat, β::Number, C::StridedVecOrMat) + _AqmulB_checkshapecompat(A, transopA, B, C) + _AqmulB_specialscale!(C, β) + _AqmulB_kernel!(α, A, transopA, B, C) + return C +end - function $(f){TA,S,Tx}(A::SparseMatrixCSC{TA,S}, x::StridedVector{Tx}) - T = promote_type(TA, Tx) - $(Symbol(f,:!))(one(T), A, x, zero(T), similar(x, T, A.n)) +qtransposefntype = Union{typeof(transpose), typeof(ctranspose)} +_AqmulB_checkshapecompat(A, ::typeof(identity), B, C) = _AqmulB_checkshapecompat(A.m, A.n, B, C) +_AqmulB_checkshapecompat(A, ::qtransposefntype, B, C) = _AqmulB_checkshapecompat(A.n, A.m, B, C) +function _AqmulB_checkshapecompat(Aqm, Aqn, B, C) + size(B, 1) == Aqn || throw(DimensionMismatch()) + size(C, 1) == Aqm || throw(DimensionMismatch()) + size(B, 2) == size(C, 2) || throw(DimensionMismatch()) +end + +_AqmulB_specialscale!(C::StridedVecOrMat, β::Number) = + β == 1 || (β == 0 ? fill!(C, zero(eltype(C))) : scale!(C, β)) + +function _AqmulB_kernel!(α::Number, A::SparseMatrixCSC, ::typeof(identity), B::StridedVector, C::StridedVector) + for colA in 1:A.n + αBforcolA = α * B[colA] + @inbounds for indA in nzrange(A, colA) + C[A.rowval[indA]] += A.nzval[indA] * αBforcolA + end + end +end +function _AqmulB_kernel!(α::Number, A::SparseMatrixCSC, ::typeof(identity), B::StridedMatrix, C::StridedMatrix) + for colA in 1:A.n, colC in 1:size(C, 2) + αBforcolAcolC = α * B[colA, colC] + @inbounds for indA in nzrange(A, colA) + C[A.rowval[indA], colC] += A.nzval[indA] * αBforcolAcolC end - function $(f){TA,S,Tx}(A::SparseMatrixCSC{TA,S}, B::StridedMatrix{Tx}) - T = promote_type(TA, Tx) - $(Symbol(f,:!))(one(T), A, B, zero(T), similar(B, T, (A.n, size(B, 2)))) + end +end +function _AqmulB_kernel!(α::Number, A::SparseMatrixCSC, op::qtransposefntype, B::StridedVector, C::StridedVector) + for colA in 1:A.n + accumulator = zero(eltype(C)) + @inbounds for indA in nzrange(A, colA) + accumulator += op(A.nzval[indA]) * B[A.rowval[indA]] + end + C[colA] += α * accumulator + end +end +function _AqmulB_kernel!(α::Number, A::SparseMatrixCSC, op::qtransposefntype, B::StridedMatrix, C::StridedMatrix) + for colA in 1:A.n, colC in 1:size(C, 2) + accumulator = zero(eltype(C)) + @inbounds for indA in nzrange(A, colA) + accumulator += op(A.nzval[indA]) * B[A.rowval[indA], colC] end + C[colA, colC] += α * accumulator end end + # For compatibility with dense multiplication API. Should be deleted when dense multiplication # API is updated to follow BLAS API. A_mul_B!(C::StridedVecOrMat, A::SparseMatrixCSC, B::StridedVecOrMat) = A_mul_B!(one(eltype(B)), A, B, zero(eltype(C)), C)