Skip to content

Commit

Permalink
Rewrite A[c|t]_mul_B[!] specializations for SparseMatrixCSC-StridedVe…
Browse files Browse the repository at this point in the history
…cOrMat combinations, without generalized linear indexing and meta-fu.
  • Loading branch information
Sacha0 committed Jan 16, 2017
1 parent ca3b06e commit 939ff69
Showing 1 changed file with 65 additions and 42 deletions.
107 changes: 65 additions & 42 deletions base/sparse/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,54 +43,77 @@ end

# In matrix-vector multiplication, the correct orientation of the vector is assumed.

for (f, op, transp) in ((:A_mul_B, :identity, false),
(:Ac_mul_B, :ctranspose, true),
(:At_mul_B, :transpose, true))
@eval begin
# In-place sparse-times-dense product, generated once per `(f, op, transp)` triple of the
# enclosing loop. C is first rescaled by β; then α times the product of A (with `op` applied
# to A's stored values and A's roles of rows/columns swapped when `transp` is true) and B is
# accumulated into C. Returns C. Throws DimensionMismatch when the shapes of A, B and C are
# incompatible.
function $(Symbol(f,:!))(α::Number, A::SparseMatrixCSC, B::StridedVecOrMat, β::Number, C::StridedVecOrMat)
# `$transp` is a literal Bool spliced in by `@eval`, so each generated method contains only
# one of the two branches' conditions as a constant.
if $transp
A.n == size(C, 1) || throw(DimensionMismatch())
A.m == size(B, 1) || throw(DimensionMismatch())
else
A.n == size(B, 1) || throw(DimensionMismatch())
A.m == size(C, 1) || throw(DimensionMismatch())
end
# B and C must have the same number of columns in either orientation.
size(B, 2) == size(C, 2) || throw(DimensionMismatch())
nzv = A.nzval
rv = A.rowval
# β == 1 leaves C untouched; β == 0 overwrites C with zeros rather than multiplying
# (presumably so that NaN/Inf already present in C do not propagate — TODO confirm).
if β != 1
β != 0 ? scale!(C, β) : fill!(C, zero(eltype(C)))
end
# Outer loop over the stored columns of A, inner loop over the columns of B/C.
for col = 1:A.n
for k = 1:size(C, 2)
if $transp
# Transposed case: column `col` of A acts as a row, so reduce its stored entries
# against B[:, k] into a scalar and add α times that to C[col, k].
tmp = zero(eltype(C))
@inbounds for j = A.colptr[col]:(A.colptr[col + 1] - 1)
tmp += $(op)(nzv[j])*B[rv[j],k]
end
C[col,k] += α*tmp
else
# Untransposed case: scatter α*B[col, k] times each stored entry of column `col`
# into the matching row of C[:, k].
αxj = α*B[col,k]
@inbounds for j = A.colptr[col]:(A.colptr[col + 1] - 1)
C[rv[j], k] += nzv[j]*αxj
end
end
end
end
C
end
# Non-mutating products of a sparse matrix (as is, transposed, or conjugate-transposed) with
# a dense vector or matrix. Each one builds the `(α, A, B, β, C)` argument tuple through
# `_argstuple_AqmulB` and splats it into the matching five-argument in-place method.
#
# The untransposed product has `A.m` rows while the transposed products have `A.n` rows, so
# `A_mul_B` requests `A.m` explicitly. Allocating `A.n` rows there makes the in-place
# method's shape check throw `DimensionMismatch` for every non-square `A`.
A_mul_B(A::SparseMatrixCSC, B::StridedVecOrMat) = A_mul_B!(_argstuple_AqmulB(A, B, A.m)...)
At_mul_B(A::SparseMatrixCSC, B::StridedVecOrMat) = At_mul_B!(_argstuple_AqmulB(A, B)...)
Ac_mul_B(A::SparseMatrixCSC, B::StridedVecOrMat) = Ac_mul_B!(_argstuple_AqmulB(A, B)...)

# Return `(one(R), A, B, zero(R), C)` where `R` is the promoted element type of `A` and the
# dense operand and `C` is a freshly allocated, uninitialized container similar to the dense
# operand with `nrowsC` rows (and, for a matrix operand, `size(B, 2)` columns).
# `nrowsC` defaults to `A.n`, the row count of the transposed products, which keeps the
# original two-argument call form working unchanged.
_argstuple_AqmulB{TvA,TB}(A::SparseMatrixCSC{TvA}, b::StridedVector{TB},
                          nrowsC::Integer = A.n) =
    (R = promote_type(TvA, TB); (one(R), A, b, zero(R), similar(b, R, nrowsC)))
_argstuple_AqmulB{TvA,TB}(A::SparseMatrixCSC{TvA}, B::StridedMatrix{TB},
                          nrowsC::Integer = A.n) =
    (R = promote_type(TvA, TB); (one(R), A, B, zero(R), similar(B, R, (nrowsC, size(B,2)))))

# Five-argument in-place products of a sparse matrix with a dense vector or matrix.
# Each method forwards to `_Aq_mul_B!`, passing the function (`identity`, `transpose` or
# `ctranspose`) that identifies which operation is applied to `A`; the result is written
# into and returned as `C`.
A_mul_B!(α::Number, A::SparseMatrixCSC, B::StridedVecOrMat, β::Number, C::StridedVecOrMat) =
    _Aq_mul_B!(α, A, identity, B, β, C)
At_mul_B!(α::Number, A::SparseMatrixCSC, B::StridedVecOrMat, β::Number, C::StridedVecOrMat) =
    _Aq_mul_B!(α, A, transpose, B, β, C)
Ac_mul_B!(α::Number, A::SparseMatrixCSC, B::StridedVecOrMat, β::Number, C::StridedVecOrMat) =
    _Aq_mul_B!(α, A, ctranspose, B, β, C)

# Shared driver behind `A_mul_B!`, `At_mul_B!` and `Ac_mul_B!`.
# `transopA` is the function identifying the operation applied to `A`; it is only used for
# dispatch by the helpers called here. The three steps run in a fixed order:
#   1. validate the shapes of `A`, `B` and `C` (throws `DimensionMismatch` on failure),
#   2. rescale `C` by `β` in place,
#   3. accumulate the `α`-scaled product into `C`.
# Returns `C`.
function _Aq_mul_B!(α::Number, A::SparseMatrixCSC, transopA::Function,
                    B::StridedVecOrMat, β::Number, C::StridedVecOrMat)
    _AqmulB_checkshapecompat(A, transopA, B, C)
    _AqmulB_specialscale!(C, β)
    _AqmulB_kernel!(α, A, transopA, B, C)
    return C
end

function $(f){TA,S,Tx}(A::SparseMatrixCSC{TA,S}, x::StridedVector{Tx})
T = promote_type(TA, Tx)
$(Symbol(f,:!))(one(T), A, x, zero(T), similar(x, T, A.n))
# Union of the singleton types of `transpose` and `ctranspose`, used to dispatch on "some
# kind of transposition was requested". Declared `const` so the binding is not an untyped
# mutable global.
const qtransposefntype = Union{typeof(transpose), typeof(ctranspose)}

# Translate the requested operation on `A` into the effective row/column counts of the left
# operand and defer to the dimension-based check: `identity` keeps `(A.m, A.n)`, either
# transposition swaps them to `(A.n, A.m)`.
_AqmulB_checkshapecompat(A, ::typeof(identity), B, C) = _AqmulB_checkshapecompat(A.m, A.n, B, C)
_AqmulB_checkshapecompat(A, ::qtransposefntype, B, C) = _AqmulB_checkshapecompat(A.n, A.m, B, C)
# Verify that a left operand with `Aqm` rows and `Aqn` columns can be multiplied with `B`
# and stored into `C`. Returns `true` when all three conditions hold; otherwise throws a
# `DimensionMismatch` whose message names the offending sizes (the exception type is the
# same as before, only the previously empty message is filled in).
function _AqmulB_checkshapecompat(Aqm, Aqn, B, C)
    size(B, 1) == Aqn || throw(DimensionMismatch(string(
        "first dimension of B, ", size(B, 1),
        ", does not match the second dimension of the left operand, ", Aqn)))
    size(C, 1) == Aqm || throw(DimensionMismatch(string(
        "first dimension of C, ", size(C, 1),
        ", does not match the first dimension of the left operand, ", Aqm)))
    size(B, 2) == size(C, 2) || throw(DimensionMismatch(string(
        "second dimension of B, ", size(B, 2),
        ", does not match the second dimension of C, ", size(C, 2))))
    return true
end

# Rescale `C` by `β` in place, with two special cases:
#   β == 1  -> `C` is left untouched (the expression short-circuits to `true`),
#   β == 0  -> `C` is overwritten with zeros via `fill!` instead of being multiplied
#              (presumably so NaN/Inf already stored in `C` cannot survive — TODO confirm).
# Any other `β` goes through `scale!(C, β)`.
_AqmulB_specialscale!(C::StridedVecOrMat, β::Number) =
    β == 1 || (β == 0 ? fill!(C, zero(eltype(C))) : scale!(C, β))

# Accumulation kernel for the untransposed matrix-vector case: adds `α * A * B` into `C`.
# It does not validate shapes or rescale `C`; the inner loop runs under `@inbounds` and so
# relies on the caller having checked that `C` is long enough for every row index stored
# in `A`.
function _AqmulB_kernel!(α::Number, A::SparseMatrixCSC, ::typeof(identity),
                         B::StridedVector, C::StridedVector)
    for colA in 1:A.n
        # Hoist the factor shared by every stored entry of this column.
        αBforcolA = α * B[colA]
        # Scatter the column's stored entries into the rows of C they belong to.
        @inbounds for indA in nzrange(A, colA)
            C[A.rowval[indA]] += A.nzval[indA] * αBforcolA
        end
    end
end
# Accumulation kernel for the untransposed matrix-matrix case: adds `α * A * B` into `C`,
# one (column of A, column of C) pair at a time.
# It does not validate shapes or rescale `C`; the inner loop runs under `@inbounds` and so
# relies on the caller having checked that `C` has enough rows for every row index stored
# in `A`.
function _AqmulB_kernel!(α::Number, A::SparseMatrixCSC, ::typeof(identity),
                         B::StridedMatrix, C::StridedMatrix)
    for colA in 1:A.n, colC in 1:size(C, 2)
        # Hoist the factor shared by every stored entry of this column of A.
        αBforcolAcolC = α * B[colA, colC]
        # Scatter the column's stored entries into the matching rows of C[:, colC].
        @inbounds for indA in nzrange(A, colA)
            C[A.rowval[indA], colC] += A.nzval[indA] * αBforcolAcolC
        end
    end
end
# Accumulation kernel for the transposed / conjugate-transposed matrix-vector case.
# For each stored column of `A` it reduces `op(value) * B[row]` over the column's stored
# entries and adds `α` times that sum to the corresponding element of `C`.
# It does not validate shapes or rescale `C`; the reduction runs under `@inbounds` and so
# relies on the caller having checked that `B` is long enough for every row index stored
# in `A`.
function _AqmulB_kernel!(α::Number, A::SparseMatrixCSC, op::qtransposefntype,
                         B::StridedVector, C::StridedVector)
    for colA in 1:A.n
        accumulator = zero(eltype(C))
        @inbounds for indA in nzrange(A, colA)
            accumulator += op(A.nzval[indA]) * B[A.rowval[indA]]
        end
        C[colA] += α * accumulator
    end
end
# Accumulation kernel for the transposed / conjugate-transposed matrix-matrix case.
# For each (column of A, column of C) pair it reduces `op(value) * B[row, colC]` over the
# stored entries of that column of `A` and adds `α` times the sum to `C[colA, colC]`.
# It does not validate shapes or rescale `C`; the reduction runs under `@inbounds` and so
# relies on the caller having checked that `B` has enough rows for every row index stored
# in `A`.
function _AqmulB_kernel!(α::Number, A::SparseMatrixCSC, op::qtransposefntype,
                         B::StridedMatrix, C::StridedMatrix)
    for colA in 1:A.n, colC in 1:size(C, 2)
        accumulator = zero(eltype(C))
        @inbounds for indA in nzrange(A, colA)
            accumulator += op(A.nzval[indA]) * B[A.rowval[indA], colC]
        end
        C[colA, colC] += α * accumulator
    end
end


# For compatibility with dense multiplication API. Should be deleted when dense multiplication
# API is updated to follow BLAS API.
function A_mul_B!(C::StridedVecOrMat, A::SparseMatrixCSC, B::StridedVecOrMat)
    # Three-argument form: overwrite C with the plain product by calling the five-argument
    # method with a unit multiplier (typed after B) and a zero weight on C (typed after C).
    unitfactor = one(eltype(B))
    zeroweight = zero(eltype(C))
    return A_mul_B!(unitfactor, A, B, zeroweight, C)
end
Expand Down

0 comments on commit 939ff69

Please sign in to comment.