-
-
Notifications
You must be signed in to change notification settings - Fork 5.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
A[c|t]_mul_B[!] specializations for SparseMatrixCSC-StridedVecOrMat, less generalized linear indexing and meta-fu, take 2 #20053
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -43,54 +43,77 @@ end | |
|
||
# In matrix-vector multiplication, the correct orientation of the vector is assumed. | ||
|
||
for (f, op, transp) in ((:A_mul_B, :identity, false), | ||
(:Ac_mul_B, :ctranspose, true), | ||
(:At_mul_B, :transpose, true)) | ||
@eval begin | ||
# In-place sparse-times-dense product, generated once per (f, op, transp) triple of the
# enclosing @eval loop. Overwrites C with α*op(A)*B + β*C and returns C.
# Throws DimensionMismatch when the shapes of A, B and C are incompatible.
function $(Symbol(f,:!))(α::Number, A::SparseMatrixCSC, B::StridedVecOrMat,
                         β::Number, C::StridedVecOrMat)
    # `$transp` is spliced in as a literal Bool; it selects how many rows op(A)
    # requires of B and of C.
    nrowsB, nrowsC = $transp ? (A.m, A.n) : (A.n, A.m)
    size(C, 1) == nrowsC || throw(DimensionMismatch())
    size(B, 1) == nrowsB || throw(DimensionMismatch())
    size(C, 2) == size(B, 2) || throw(DimensionMismatch())
    vals, rows = A.nzval, A.rowval
    # Apply β up front: nothing to do when β == 1, and β == 0 overwrites C with zeros
    # rather than multiplying.
    β == 1 || (β == 0 ? fill!(C, zero(eltype(C))) : scale!(C, β))
    for j = 1:A.n, k = 1:size(C, 2)
        if $transp
            # Transposed case: entry (j, k) of the product is a dot product of
            # stored column j of A with column k of B.
            acc = zero(eltype(C))
            @inbounds for p = A.colptr[j]:(A.colptr[j + 1] - 1)
                acc += $(op)(vals[p]) * B[rows[p], k]
            end
            C[j, k] += α * acc
        else
            # Untransposed case: scatter α*B[j, k] times stored column j of A
            # into column k of C.
            scaled = α * B[j, k]
            @inbounds for p = A.colptr[j]:(A.colptr[j + 1] - 1)
                C[rows[p], k] += vals[p] * scaled
            end
        end
    end
    return C
end
# Allocating products of a sparse matrix (optionally transposed / conjugate-transposed)
# with a dense vector or matrix. Each builds the (α, A, B, β, C) argument tuple and
# forwards it to the matching in-place method.
#
# The untransposed product has as many rows as A (A.m); the transposed products have as
# many rows as A has columns (A.n). Allocating A.n rows in every case made A_mul_B fail
# the in-place shape check for non-square A.
A_mul_B(A::SparseMatrixCSC, B::StridedVecOrMat) =
    A_mul_B!(_argstuple_AqmulB(A, B, A.m)...)
At_mul_B(A::SparseMatrixCSC, B::StridedVecOrMat) =
    At_mul_B!(_argstuple_AqmulB(A, B, A.n)...)
Ac_mul_B(A::SparseMatrixCSC, B::StridedVecOrMat) =
    Ac_mul_B!(_argstuple_AqmulB(A, B, A.n)...)

# Build the five-argument tuple (one(R), A, B, zero(R), C) for the in-place methods, where
# R is the promoted element type of A and B and C is a freshly allocated output with
# `nrowsC` rows. `nrowsC` defaults to A.n so existing two-argument calls keep working.
function _argstuple_AqmulB{TvA,TB}(A::SparseMatrixCSC{TvA}, b::StridedVector{TB},
                                   nrowsC::Integer = A.n)
    R = promote_type(TvA, TB)
    return (one(R), A, b, zero(R), similar(b, R, nrowsC))
end
function _argstuple_AqmulB{TvA,TB}(A::SparseMatrixCSC{TvA}, B::StridedMatrix{TB},
                                   nrowsC::Integer = A.n)
    R = promote_type(TvA, TB)
    return (one(R), A, B, zero(R), similar(B, R, (nrowsC, size(B, 2))))
end
|
||
# Five-argument in-place products. Each forwards to the shared driver `_Aq_mul_B!`,
# passing the function that describes how A enters the product: `identity`,
# `transpose`, or `ctranspose`.
function A_mul_B!(α::Number, A::SparseMatrixCSC, B::StridedVecOrMat,
                  β::Number, C::StridedVecOrMat)
    return _Aq_mul_B!(α, A, identity, B, β, C)
end
function At_mul_B!(α::Number, A::SparseMatrixCSC, B::StridedVecOrMat,
                   β::Number, C::StridedVecOrMat)
    return _Aq_mul_B!(α, A, transpose, B, β, C)
end
function Ac_mul_B!(α::Number, A::SparseMatrixCSC, B::StridedVecOrMat,
                   β::Number, C::StridedVecOrMat)
    return _Aq_mul_B!(α, A, ctranspose, B, β, C)
end
|
||
# Shared driver behind the five-argument A[c|t]_mul_B! methods. `transopA` is the
# function (identity / transpose / ctranspose) that the helpers below dispatch on.
# Mutates C in place and returns it.
function _Aq_mul_B!(α::Number, A::SparseMatrixCSC, transopA::Function,
                    B::StridedVecOrMat, β::Number, C::StridedVecOrMat)
    # 1. Reject incompatible shapes before C is touched.
    _AqmulB_checkshapecompat(A, transopA, B, C)
    # 2. Apply the β factor to the existing contents of C.
    _AqmulB_specialscale!(C, β)
    # 3. Accumulate the α-scaled product into C.
    _AqmulB_kernel!(α, A, transopA, B, C)
    return C
end
|
||
function $(f){TA,S,Tx}(A::SparseMatrixCSC{TA,S}, x::StridedVector{Tx}) | ||
T = promote_type(TA, Tx) | ||
$(Symbol(f,:!))(one(T), A, x, zero(T), similar(x, T, A.n)) | ||
# Type of the two transposing operations; used purely for dispatch below. Declared
# `const` so the module-level binding cannot be rebound to something else after the
# methods that dispatch on it have been defined.
const qtransposefntype = Union{typeof(transpose), typeof(ctranspose)}

# Shape check for C against op(A)*B. These methods translate the operation applied to A
# into the (rows, columns) of op(A) and defer to the four-integer/array method:
# `identity` keeps (A.m, A.n); either transpose swaps them to (A.n, A.m).
_AqmulB_checkshapecompat(A, ::typeof(identity), B, C) =
    _AqmulB_checkshapecompat(A.m, A.n, B, C)
_AqmulB_checkshapecompat(A, ::qtransposefntype, B, C) =
    _AqmulB_checkshapecompat(A.n, A.m, B, C)
# Verify that B and C are shape-compatible with a left operand that has `Aqm` rows and
# `Aqn` columns: B must have `Aqn` rows, C must have `Aqm` rows, and B and C must have
# the same number of columns. Returns `true` on success; throws a DimensionMismatch
# naming the offending dimension otherwise.
function _AqmulB_checkshapecompat(Aqm, Aqn, B, C)
    size(B, 1) == Aqn || throw(DimensionMismatch(
        "left operand has $Aqn columns, but B has $(size(B, 1)) rows"))
    size(C, 1) == Aqm || throw(DimensionMismatch(
        "left operand has $Aqm rows, but C has $(size(C, 1)) rows"))
    size(B, 2) == size(C, 2) || throw(DimensionMismatch(
        "B has $(size(B, 2)) columns, but C has $(size(C, 2)) columns"))
    return true
end
|
||
# Apply the β factor to C in place before the product is accumulated into it.
#   β == 1: C is left untouched (the call evaluates to `true`).
#   β == 0: C is overwritten with zeros of its element type rather than multiplied.
#   otherwise: C is scaled by β.
function _AqmulB_specialscale!(C::StridedVecOrMat, β::Number)
    β == 1 && return true
    return β == 0 ? fill!(C, zero(eltype(C))) : scale!(C, β)
end
|
||
function _AqmulB_kernel!(α::Number, A::SparseMatrixCSC, ::typeof(identity), B::StridedVector, C::StridedVector) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Likewise, Julian argument order would have the output array first. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I think this should be discussed elsewhere. It is not really relevant to this PR. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We're adding new functions here; they should make sense. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Agreed (edit: with op). When changing the user-facing methods' argument order, changing the kernel's argument order to match would be great. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. To clarify, I agree that changing the argument order of these five-argument […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The kernel has a different set of inputs than the outer function, so I'm less attached to the kernel needing to have the same argument order. But sure, this can wait - it should be refactored when a matrix transpose type allows cleaning up all these names; we can also fix the argument order and drop the Fortran conventions once dispatch is doing most of the work. I am a bit curious how widely used the gemm-style argument orders are for these functions (and why they aren't being called with the gemm name in those cases).
If the cost of keyword arguments gets solved soon, it would probably make more sense to do the scalar multiples that way, maybe even with more obvious names. This is post-0.6 brainstorming, though. Given mbauman's reworking of the generalized linear indexing to possibly not deprecate indexing with trailing 1s yet, would these PRs still be relevant? |
||
for colA in 1:A.n | ||
αBforcolA = α * B[colA] | ||
@inbounds for indA in nzrange(A, colA) | ||
C[A.rowval[indA]] += A.nzval[indA] * αBforcolA | ||
end | ||
end | ||
end | ||
function _AqmulB_kernel!(α::Number, A::SparseMatrixCSC, ::typeof(identity), B::StridedMatrix, C::StridedMatrix) | ||
for colA in 1:A.n, colC in 1:size(C, 2) | ||
αBforcolAcolC = α * B[colA, colC] | ||
@inbounds for indA in nzrange(A, colA) | ||
C[A.rowval[indA], colC] += A.nzval[indA] * αBforcolAcolC | ||
end | ||
function $(f){TA,S,Tx}(A::SparseMatrixCSC{TA,S}, B::StridedMatrix{Tx}) | ||
T = promote_type(TA, Tx) | ||
$(Symbol(f,:!))(one(T), A, B, zero(T), similar(B, T, (A.n, size(B, 2)))) | ||
end | ||
end | ||
# Kernel for C += α * op(A) * B with `op` a (conjugate) transpose and B, C vectors.
# For each column j of A, the stored entries of that column (with `op` applied
# elementwise) are dotted with the matching entries of B, and α times the result is
# added to C[j]. Returns `nothing`.
function _AqmulB_kernel!(α::Number, A::SparseMatrixCSC, op::qtransposefntype,
                         B::StridedVector, C::StridedVector)
    vals, rows = A.nzval, A.rowval
    for j in 1:A.n
        acc = zero(eltype(C))
        # Only the sparse-structure lookups run without bounds checks; the write
        # to C below stays checked.
        @inbounds for p in nzrange(A, j)
            acc += op(vals[p]) * B[rows[p]]
        end
        C[j] += α * acc
    end
    return nothing
end
# Kernel for C += α * op(A) * B with `op` a (conjugate) transpose and B, C matrices.
# For each column j of A and each column k of C, the stored entries of column j of A
# (with `op` applied elementwise) are dotted with the matching entries of column k of
# B, and α times the result is added to C[j, k]. Returns `nothing`.
function _AqmulB_kernel!(α::Number, A::SparseMatrixCSC, op::qtransposefntype,
                         B::StridedMatrix, C::StridedMatrix)
    vals, rows = A.nzval, A.rowval
    ncolsC = size(C, 2)
    for j in 1:A.n, k in 1:ncolsC
        acc = zero(eltype(C))
        # Only the sparse-structure lookups run without bounds checks; the write
        # to C below stays checked.
        @inbounds for p in nzrange(A, j)
            acc += op(vals[p]) * B[rows[p], k]
        end
        C[j, k] += α * acc
    end
    return nothing
end
|
||
|
||
# Three-argument form kept for compatibility with the dense multiplication API. It calls
# the five-argument method with α = one(eltype(B)) and β = zero(eltype(C)).
# Should be deleted when the dense multiplication API is updated to follow the BLAS API.
function A_mul_B!(C::StridedVecOrMat, A::SparseMatrixCSC, B::StridedVecOrMat)
    α = one(eltype(B))
    β = zero(eltype(C))
    return A_mul_B!(α, A, B, β, C)
end
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We should really deprecate this argument order; it's not called gemm, so there's no reason for it to follow Fortran conventions instead of Julia ones.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed. Referencing #16772 as a bread crumb to this case. Thanks!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
related but not quite identical