Skip to content

Commit

Permalink
optimize transposed matrix multiply. closes issue #313.
Browse files Browse the repository at this point in the history
so far I've only implemented vector*vector and matrix*matrix cases.
not sure how best to deal with 2x2 and 3x3, but this should be good enough.
  • Loading branch information
JeffBezanson committed Dec 15, 2011
1 parent 33abec4 commit 361588d
Show file tree
Hide file tree
Showing 5 changed files with 145 additions and 24 deletions.
62 changes: 50 additions & 12 deletions j/linalg.j
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
## linalg.j: Basic Linear Algebra functions ##

# x'*y for two vectors, answered with dot(x, y) rather than by building x'.
# NOTE(review): for complex x this matches x'*y only if dot conjugates its
# first argument -- confirm against dot's loop body.
function aCb(x::AbstractVector, y::AbstractVector)
    return dot(x, y)
end
# x.'*y for two real vectors sharing an element type, answered with dot(x, y)
# rather than by building the transposed vector.
function aTb{T<:Real}(x::AbstractVector{T}, y::AbstractVector{T})
    return dot(x, y)
end

function dot(x::AbstractVector, y::AbstractVector)
s = zero(eltype(x))
for i=1:length(x)
Expand All @@ -17,6 +20,7 @@ cross(a::AbstractVector, b::AbstractVector) =
# TODO: It will be faster for large matrices to convert to float,
# call BLAS, and convert back to required type.

# TODO: support transposed arguments
function (*){T,S}(A::AbstractMatrix{T}, B::AbstractVector{S})
mA = size(A, 1)
mB = size(B, 1)
Expand All @@ -30,6 +34,7 @@ function (*){T,S}(A::AbstractMatrix{T}, B::AbstractVector{S})
return C
end

# TODO: support transposed arguments
function (*){T,S}(A::AbstractVector{S}, B::AbstractMatrix{T})
nA = size(A, 1)
nB = size(B, 2)
Expand All @@ -45,11 +50,12 @@ function (*){T,S}(A::AbstractVector{S}, B::AbstractMatrix{T})
return C
end

# TODO: support transposed arguments
function (*){T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
(mA, nA) = size(A)
(mB, nB) = size(B)
if mA == 2 && nA == 2 && nB == 2; return matmul2x2(A,B); end
if mA == 3 && nA == 3 && nB == 3; return matmul3x3(A,B); end
if mA == 2 && nA == 2 && nB == 2; return matmul2x2('N','N',A,B); end
if mA == 3 && nA == 3 && nB == 3; return matmul3x3('N','N',A,B); end
C = zeros(promote_type(T,S), mA, nB)
z = zero(eltype(C))

Expand Down Expand Up @@ -78,12 +84,24 @@ function (*){T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
end

# multiply 2x2 matrices
function matmul2x2{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
function matmul2x2{T,S}(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
R = promote_type(T,S)
C = Array(R, 2, 2)

A11 = A[1,1]; A12 = A[1,2]; A21 = A[2,1]; A22 = A[2,2]
B11 = B[1,1]; B12 = B[1,2]; B21 = B[2,1]; B22 = B[2,2]
if tA == 'T'
A11 = A[1,1]; A12 = A[2,1]; A21 = A[1,2]; A22 = A[2,2]
elseif tA == 'C'
A11 = conj(A[1,1]); A12 = conj(A[2,1]); A21 = conj(A[1,2]); A22 = conj(A[2,2])
else
A11 = A[1,1]; A12 = A[1,2]; A21 = A[2,1]; A22 = A[2,2]
end
if tB == 'T'
B11 = B[1,1]; B12 = B[2,1]; B21 = B[1,2]; B22 = B[2,2]
elseif tB == 'C'
B11 = conj(B[1,1]); B12 = conj(B[2,1]); B21 = conj(B[1,2]); B22 = conj(B[2,2])
else
B11 = B[1,1]; B12 = B[1,2]; B21 = B[2,1]; B22 = B[2,2]
end

C[1,1] = A11*B11 + A12*B21
C[1,2] = A11*B12 + A12*B22
Expand All @@ -93,17 +111,37 @@ function matmul2x2{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
return C
end

function matmul3x3{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
function matmul3x3{T,S}(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
R = promote_type(T,S)
C = Array(R, 3, 3)

A11 = A[1,1]; A12 = A[1,2]; A13 = A[1,3];
A21 = A[2,1]; A22 = A[2,2]; A23 = A[2,3];
A31 = A[3,1]; A32 = A[3,2]; A33 = A[3,3];
if tA == 'T'
A11 = A[1,1]; A12 = A[2,1]; A13 = A[3,1];
A21 = A[1,2]; A22 = A[2,2]; A23 = A[3,2];
A31 = A[1,3]; A32 = A[2,3]; A33 = A[3,3];
elseif tA == 'C'
A11 = conj(A[1,1]); A12 = conj(A[2,1]); A13 = conj(A[3,1]);
A21 = conj(A[1,2]); A22 = conj(A[2,2]); A23 = conj(A[3,2]);
A31 = conj(A[1,3]); A32 = conj(A[2,3]); A33 = conj(A[3,3]);
else
A11 = A[1,1]; A12 = A[1,2]; A13 = A[1,3];
A21 = A[2,1]; A22 = A[2,2]; A23 = A[2,3];
A31 = A[3,1]; A32 = A[3,2]; A33 = A[3,3];
end

B11 = B[1,1]; B12 = B[1,2]; B13 = B[1,3];
B21 = B[2,1]; B22 = B[2,2]; B23 = B[2,3];
B31 = B[3,1]; B32 = B[3,2]; B33 = B[3,3];
if tB == 'T'
B11 = B[1,1]; B12 = B[2,1]; B13 = B[3,1];
B21 = B[1,2]; B22 = B[2,2]; B23 = B[3,2];
B31 = B[1,3]; B32 = B[2,3]; B33 = B[3,3];
elseif tB == 'C'
B11 = conj(B[1,1]); B12 = conj(B[2,1]); B13 = conj(B[3,1]);
B21 = conj(B[1,2]); B22 = conj(B[2,2]); B23 = conj(B[3,2]);
B31 = conj(B[1,3]); B32 = conj(B[2,3]); B33 = conj(B[3,3]);
else
B11 = B[1,1]; B12 = B[1,2]; B13 = B[1,3];
B21 = B[2,1]; B22 = B[2,2]; B23 = B[2,3];
B31 = B[3,1]; B32 = B[3,2]; B33 = B[3,3];
end

C[1,1] = A11*B11 + A12*B21 + A13*B31
C[1,2] = A11*B12 + A12*B22 + A13*B32
Expand Down
76 changes: 66 additions & 10 deletions j/linalg_blas.j
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ macro _jl_blas_gemm_macro(fname, eltype)
Ptr{$eltype}, Ptr{$eltype}, Ptr{Int32},
Ptr{$eltype}, Ptr{Int32},
Ptr{$eltype}, Ptr{$eltype}, Ptr{Int32}),
transA, transB, int32(m), int32(n), int32(k),
uint8(transA), uint8(transB), int32(m), int32(n), int32(k),
alpha, a, int32(lda),
b, int32(ldb),
beta, c, int32(ldc))
Expand All @@ -149,25 +149,79 @@ end

# A*B for strided matrices that share a BLAS element type.
# All shape checking, the small-matrix shortcuts and the BLAS call live in
# _jl_gemm; both operands are passed with the 'N' (use as stored) flag.
function (*){T<:Union(Float64,Float32,Complex128,Complex64)}(A::StridedMatrix{T},
                                                             B::StridedMatrix{T})
    # The sizes are not needed here: _jl_gemm reads them itself.
    _jl_gemm('N', 'N', A, B)
end

# A.'*B for strided matrices that share a BLAS element type: hand both
# operands to _jl_gemm untouched, flagging A as transposed ('T').
aTb{T<:Union(Float64,Float32,Complex128,Complex64)}(A::StridedMatrix{T}, B::StridedMatrix{T}) =
    _jl_gemm('T', 'N', A, B)

# A*B.' for strided matrices that share a BLAS element type: hand both
# operands to _jl_gemm untouched, flagging B as transposed ('T').
abT{T<:Union(Float64,Float32,Complex128,Complex64)}(A::StridedMatrix{T}, B::StridedMatrix{T}) =
    _jl_gemm('N', 'T', A, B)

# A.'*B.' for strided matrices that share a BLAS element type: hand both
# operands to _jl_gemm untouched, flagging each as transposed ('T').
aTbT{T<:Union(Float64,Float32,Complex128,Complex64)}(A::StridedMatrix{T}, B::StridedMatrix{T}) =
    _jl_gemm('T', 'T', A, B)

# A'*B for strided matrices that share a BLAS element type: hand both
# operands to _jl_gemm untouched, flagging A as conjugate-transposed ('C').
aCb{T<:Union(Float64,Float32,Complex128,Complex64)}(A::StridedMatrix{T}, B::StridedMatrix{T}) =
    _jl_gemm('C', 'N', A, B)

# A*B' for strided matrices that share a BLAS element type: hand both
# operands to _jl_gemm untouched, flagging B as conjugate-transposed ('C').
abC{T<:Union(Float64,Float32,Complex128,Complex64)}(A::StridedMatrix{T}, B::StridedMatrix{T}) =
    _jl_gemm('N', 'C', A, B)

# A'*B' for strided matrices that share a BLAS element type: hand both
# operands to _jl_gemm untouched, flagging each as conjugate-transposed ('C').
aCbC{T<:Union(Float64,Float32,Complex128,Complex64)}(A::StridedMatrix{T}, B::StridedMatrix{T}) =
    _jl_gemm('C', 'C', A, B)

# Multiply two strided matrices of the same BLAS element type and return
# op(A)*op(B) as a newly allocated matrix.
#
# tA, tB: per-operand flag. 'N' uses the operand as stored, 'T' uses its
#         transpose (.'), 'C' uses its conjugate transpose (').
# A, B:   the operands exactly as stored (NOT pre-transposed).
#
# Raises an error when the inner dimensions of op(A) and op(B) disagree.
function _jl_gemm{T<:Union(Float64,Float32,Complex128,Complex64)}(tA, tB,
                                                                  A::StridedMatrix{T},
                                                                  B::StridedMatrix{T})
    # mA-by-nA and mB-by-nB are the shapes of op(A) and op(B), so a flagged
    # operand has its stored dimensions swapped.
    if tA != 'N'
        (nA, mA) = size(A)
    else
        (mA, nA) = size(A)
    end
    if tB != 'N'
        (nB, mB) = size(B)
    else
        (mB, nB) = size(B)
    end

    if nA != mB; error("*: argument shapes do not match"); end

    # Small fixed-size products go to the hand-unrolled kernels. The flags
    # must be forwarded, otherwise the transposition is silently dropped.
    if mA == 2 && nA == 2 && nB == 2; return matmul2x2(tA,tB,A,B); end
    if mA == 3 && nA == 3 && nB == 3; return matmul3x3(tA,tB,A,B); end

    # Operands whose first dimension is not unit-stride are not passed to
    # BLAS: apply the requested transposes explicitly and use the generic
    # matrix multiply instead.
    if stride(A, 1) != 1 || stride(B, 1) != 1
        if tA == 'T'
            A = A.'
        elseif tA == 'C'
            A = A'
        end
        if tB == 'T'
            B = B.'
        elseif tB == 'C'
            B = B'
        end
        return invoke(*, (AbstractMatrix, AbstractMatrix), A, B)
    end

    # Result array does not need to be initialized as long as beta==0
    C = Array(T, mA, nB)

    # Single GEMM call: the flags tell BLAS how to read A and B, and the
    # leading dimensions are those of the operands as stored.
    _jl_blas_gemm(tA, tB, mA, nB, nA,
                  one(T), A, stride(A, 2),
                  B, stride(B, 2),
                  zero(T), C, mA)
    return C
end

Expand Down Expand Up @@ -209,6 +263,7 @@ end
@_jl_blas_gemv_macro :zgemv_ Complex128
@_jl_blas_gemv_macro :cgemv_ Complex64

# TODO: support transposed arguments
function (*){T<:Union(Float64,Float32,Complex128,Complex64)}(A::StridedMatrix{T},
X::StridedVector{T})
(mA, nA) = size(A)
Expand All @@ -230,6 +285,7 @@ function (*){T<:Union(Float64,Float32,Complex128,Complex64)}(A::StridedMatrix{T}
return Y
end

# TODO: support transposed arguments
function (*){T<:Union(Float64,Float32,Complex128,Complex64)}(X::StridedVector{T},
A::StridedMatrix{T})
nX = size(X, 1)
Expand Down
8 changes: 8 additions & 0 deletions j/operators.j
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,14 @@ mod1{T<:Real}(x::T, y::T) = y-mod(y-x,y)
# cmp returns -1, 0, +1 indicating ordering
cmp{T<:Real}(x::T, y::T) = sign(x-y)

# transposed multiply
# Generic fallbacks for the fused transpose-and-multiply operators. Each one
# applies ctranspose (suffix C) or transpose (suffix T) to the operand whose
# letter precedes the suffix, then multiplies the results.
function aCb(a, b)
    return ctranspose(a) * b
end

function abC(a, b)
    return a * ctranspose(b)
end

function aCbC(a, b)
    return ctranspose(a) * ctranspose(b)
end

function aTb(a, b)
    return transpose(a) * b
end

function abT(a, b)
    return a * transpose(b)
end

function aTbT(a, b)
    return transpose(a) * transpose(b)
end

oftype{T}(::Type{T},c) = convert(T,c)
oftype{T}(x::T,c) = convert(T,c)

Expand Down
19 changes: 19 additions & 0 deletions src/julia-syntax.scm
Original file line number Diff line number Diff line change
Expand Up @@ -909,6 +909,25 @@
(pattern-lambda (|'| a) `(call ctranspose ,a))
(pattern-lambda (|.'| a) `(call transpose ,a))

;; transposed multiply
;; Each rule rewrites a `*` call whose operand(s) carry a postfix ' or .' into
;; a call to a dedicated function that receives the operands WITHOUT the
;; postfix operator applied. In the target name, C after an operand letter
;; stands for ' on that operand and T stands for .' on that operand.
;; NOTE(review): the free variable of a one-sided rule could presumably also
;; bind an operand that itself carries ' or .', so whether the two-sided rules
;; (aCbC, aTbT) are ever selected depends on how the enclosing pattern set
;; orders and picks rules -- confirm against the matcher.
(pattern-lambda (call (-/ *) (|'| a) b)
`(call aCb ,a ,b))

(pattern-lambda (call (-/ *) a (|'| b))
`(call abC ,a ,b))

(pattern-lambda (call (-/ *) (|'| a) (|'| b))
`(call aCbC ,a ,b))

(pattern-lambda (call (-/ *) (|.'| a) b)
`(call aTb ,a ,b))

(pattern-lambda (call (-/ *) a (|.'| b))
`(call abT ,a ,b))

(pattern-lambda (call (-/ *) (|.'| a) (|.'| b))
`(call aTbT ,a ,b))

)) ; patterns

; patterns that verify all syntactic sugar was well-formed
Expand Down
4 changes: 2 additions & 2 deletions test/bench/hpl_par.j
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ function trailing_update(L_II, A_IJ, A_KI, A_KJ, row_dep, col_dep)
if !isempty(A_KJ)
m, k = size(A_KI)
n = size(A_IJ,2)
_jl_blas_gemm("N","N",m,n,k,-1.0,A_KI,m,A_IJ,k,1.0,A_KJ,m)
_jl_blas_gemm('N','N',m,n,k,-1.0,A_KI,m,A_IJ,k,1.0,A_KJ,m)
#A_KJ = A_KJ - A_KI*A_IJ
end

Expand Down Expand Up @@ -268,7 +268,7 @@ function trailing_update2(C, L_II, C_KI, i, j, n, flag, dep)
if !isempty(C_KJ)
cm, ck = size(C_KI)
cn = size(C_IJ,2)
_jl_blas_gemm("N","N",cm,cn,ck,-1.0,C_KI,cm,C_IJ,ck,1.0,C_KJ,cm)
_jl_blas_gemm('N','N',cm,cn,ck,-1.0,C_KI,cm,C_IJ,ck,1.0,C_KJ,cm)
#C_KJ = C_KJ - C_KI*C_IJ
C[K,J] = C_KJ
end
Expand Down

1 comment on commit 361588d

@ViralBShah
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't we prefix these operator names with something like jl_matmul?

Please sign in to comment.