Skip to content

Commit

Permalink
Make (c)transpose work correctly for block matrices
Browse files Browse the repository at this point in the history
  • Loading branch information
andreasnoack committed Jun 17, 2014
1 parent c5107b3 commit ab795c8
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 43 deletions.
12 changes: 6 additions & 6 deletions base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1235,7 +1235,7 @@ function transpose!(B::StridedMatrix,A::StridedMatrix)
@inbounds begin
for j = 1:n
for i = 1:m
B[j,i] = A[i,j]
B[j,i] = transpose(A[i,j])
end
end
end
Expand All @@ -1249,7 +1249,7 @@ function transposeblock!(B::StridedMatrix,A::StridedMatrix,m::Int,n::Int,offseti
@inbounds begin
for j = offsetj+(1:n)
for i = offseti+(1:m)
B[j,i] = A[i,j]
B[j,i] = transpose(A[i,j])
end
end
end
Expand All @@ -1272,7 +1272,7 @@ function ctranspose!(B::StridedMatrix,A::StridedMatrix)
@inbounds begin
for j = 1:n
for i = 1:m
B[j,i] = conj(A[i,j])
B[j,i] = ctranspose(A[i,j])
end
end
end
Expand All @@ -1286,7 +1286,7 @@ function ctransposeblock!(B::StridedMatrix,A::StridedMatrix,m::Int,n::Int,offset
@inbounds begin
for j = offsetj+(1:n)
for i = offseti+(1:m)
B[j,i] = conj(A[i,j])
B[j,i] = ctranspose(A[i,j])
end
end
end
Expand All @@ -1312,8 +1312,8 @@ function ctranspose(A::StridedMatrix)
end
ctranspose{T<:Real}(A::StridedVecOrMat{T}) = transpose(A)

transpose(x::StridedVector) = [ x[j] for i=1, j=1:size(x,1) ]
ctranspose{T}(x::StridedVector{T}) = T[ conj(x[j]) for i=1, j=1:size(x,1) ]
transpose(x::StridedVector) = [ transpose(x[j]) for i=1, j=1:size(x,1) ]
ctranspose{T}(x::StridedVector{T}) = T[ ctranspose(x[j]) for i=1, j=1:size(x,1) ]

# set-like operators for vectors
# These are moderately efficient, preserve order, and remove dupes.
Expand Down
4 changes: 2 additions & 2 deletions base/linalg/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ function issym(A::AbstractMatrix)
m, n = size(A)
m==n || return false
for i = 1:(n-1), j = (i+1):n
if A[i,j] != A[j,i]
if A[i,j] != transpose(A[j,i])
return false
end
end
Expand All @@ -258,7 +258,7 @@ function ishermitian(A::AbstractMatrix)
m, n = size(A)
m==n || return false
for i = 1:n, j = i:n
if A[i,j] != conj(A[j,i])
if A[i,j] != ctranspose(A[j,i])
return false
end
end
Expand Down
70 changes: 35 additions & 35 deletions base/linalg/matmul.jl
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ function generic_matvecmul!{T,S,R}(C::AbstractVector{R}, tA, A::AbstractMatrix{T
aoffs = (k-1)*Astride
s = z
for i = 1:nA
s += A[aoffs+i] * B[i]
s += A[aoffs+i].'B[i]
end
C[k] = s
end
Expand All @@ -334,11 +334,11 @@ function generic_matvecmul!{T,S,R}(C::AbstractVector{R}, tA, A::AbstractMatrix{T
aoffs = (k-1)*Astride
s = z
for i = 1:nA
s += conj(A[aoffs+i]) * B[i]
s += A[aoffs+i]'B[i]
end
C[k] = s
end
else # tA == 'N'
else # tA == 'N'
fill!(C, z)
for k = 1:mB
aoffs = (k-1)*Astride
Expand Down Expand Up @@ -438,69 +438,69 @@ function generic_matmatmul!{T,S,R}(C::AbstractVecOrMat{R}, tA, tB, A::AbstractVe
end
elseif tB == 'T'
for i = 1:mA, j = 1:nB
Ctmp = A[i, 1]*B[j, 1]
Ctmp = A[i, 1]*B[j, 1].'
for k = 2:nA
Ctmp += A[i, k]*B[j, k]
Ctmp += A[i, k]*B[j, k].'
end
C[i,j] = Ctmp
end
else
for i = 1:mA, j = 1:nB
Ctmp = A[i, 1]*conj(B[j, 1])
Ctmp = A[i, 1]*B[j, 1]'
for k = 2:nA
Ctmp += A[i, k]*conj(B[j, k])
Ctmp += A[i, k]*B[j, k]'
end
C[i,j] = Ctmp
end
end
elseif tA == 'T'
if tB == 'N'
for i = 1:mA, j = 1:nB
Ctmp = A[1, i]*B[1, j]
Ctmp = A[1, i].'B[1, j]
for k = 2:nA
Ctmp += A[k, i]*B[k, j]
Ctmp += A[k, i].'B[k, j]
end
C[i,j] = Ctmp
end
elseif tB == 'T'
for i = 1:mA, j = 1:nB
Ctmp = A[1, i]*B[j, 1]
Ctmp = A[1, i].'B[j, 1].'
for k = 2:nA
Ctmp += A[k, i]*B[j, k]
Ctmp += A[k, i].'B[j, k].'
end
C[i,j] = Ctmp
end
else
for i = 1:mA, j = 1:nB
Ctmp = A[1, i]*conj(B[j, 1])
Ctmp = A[1, i].'B[j, 1]'
for k = 2:nA
Ctmp += A[k, i]*conj(B[j, k])
Ctmp += A[k, i].'B[j, k]'
end
C[i,j] = Ctmp
end
end
else
if tB == 'N'
for i = 1:mA, j = 1:nB
Ctmp = conj(A[1, i])*B[1, j]
Ctmp = A[1, i]'B[1, j]
for k = 2:nA
Ctmp += conj(A[k, i])*B[k, j]
Ctmp += A[k, i]'B[k, j]
end
C[i,j] = Ctmp
end
elseif tB == 'T'
for i = 1:mA, j = 1:nB
Ctmp = conj(A[1, i])*B[j, 1]
Ctmp = A[1, i]'B[j, 1].'
for k = 2:nA
Ctmp += conj(A[k, i])*B[j, k]
Ctmp += A[k, i]'B[j, k].'
end
C[i,j] = Ctmp
end
else
for i = 1:mA, j = 1:nB
Ctmp = conj(A[1, i])*conj(B[j, 1])
Ctmp = A[1, i]'B[j, 1]'
for k = 2:nA
Ctmp += conj(A[k, i])*conj(B[j, k])
Ctmp += A[k, i]'B[j, k]'
end
C[i,j] = Ctmp
end
Expand All @@ -519,16 +519,16 @@ end

function matmul2x2!{T,S,R}(C::AbstractMatrix{R}, tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
if tA == 'T'
A11 = A[1,1]; A12 = A[2,1]; A21 = A[1,2]; A22 = A[2,2]
A11 = transpose(A[1,1]); A12 = transpose(A[2,1]); A21 = transpose(A[1,2]); A22 = transpose(A[2,2])
elseif tA == 'C'
A11 = conj(A[1,1]); A12 = conj(A[2,1]); A21 = conj(A[1,2]); A22 = conj(A[2,2])
A11 = ctranspose(A[1,1]); A12 = ctranspose(A[2,1]); A21 = ctranspose(A[1,2]); A22 = ctranspose(A[2,2])
else
A11 = A[1,1]; A12 = A[1,2]; A21 = A[2,1]; A22 = A[2,2]
end
if tB == 'T'
B11 = B[1,1]; B12 = B[2,1]; B21 = B[1,2]; B22 = B[2,2]
B11 = transpose(B[1,1]); B12 = transpose(B[2,1]); B21 = transpose(B[1,2]); B22 = transpose(B[2,2])
elseif tB == 'C'
B11 = conj(B[1,1]); B12 = conj(B[2,1]); B21 = conj(B[1,2]); B22 = conj(B[2,2])
B11 = ctranspose(B[1,1]); B12 = ctranspose(B[2,1]); B21 = ctranspose(B[1,2]); B22 = ctranspose(B[2,2])
else
B11 = B[1,1]; B12 = B[1,2]; B21 = B[2,1]; B22 = B[2,2]
end
Expand All @@ -546,27 +546,27 @@ end

function matmul3x3!{T,S,R}(C::AbstractMatrix{R}, tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
if tA == 'T'
A11 = A[1,1]; A12 = A[2,1]; A13 = A[3,1];
A21 = A[1,2]; A22 = A[2,2]; A23 = A[3,2];
A31 = A[1,3]; A32 = A[2,3]; A33 = A[3,3];
A11 = transpose(A[1,1]); A12 = transpose(A[2,1]); A13 = transpose(A[3,1]);
A21 = transpose(A[1,2]); A22 = transpose(A[2,2]); A23 = transpose(A[3,2]);
A31 = transpose(A[1,3]); A32 = transpose(A[2,3]); A33 = transpose(A[3,3]);
elseif tA == 'C'
A11 = conj(A[1,1]); A12 = conj(A[2,1]); A13 = conj(A[3,1]);
A21 = conj(A[1,2]); A22 = conj(A[2,2]); A23 = conj(A[3,2]);
A31 = conj(A[1,3]); A32 = conj(A[2,3]); A33 = conj(A[3,3]);
A11 = ctranspose(A[1,1]); A12 = ctranspose(A[2,1]); A13 = ctranspose(A[3,1]);
A21 = ctranspose(A[1,2]); A22 = ctranspose(A[2,2]); A23 = ctranspose(A[3,2]);
A31 = ctranspose(A[1,3]); A32 = ctranspose(A[2,3]); A33 = ctranspose(A[3,3]);
else
A11 = A[1,1]; A12 = A[1,2]; A13 = A[1,3];
A21 = A[2,1]; A22 = A[2,2]; A23 = A[2,3];
A31 = A[3,1]; A32 = A[3,2]; A33 = A[3,3];
end

if tB == 'T'
B11 = B[1,1]; B12 = B[2,1]; B13 = B[3,1];
B21 = B[1,2]; B22 = B[2,2]; B23 = B[3,2];
B31 = B[1,3]; B32 = B[2,3]; B33 = B[3,3];
B11 = transpose(B[1,1]); B12 = transpose(B[2,1]); B13 = transpose(B[3,1]);
B21 = transpose(B[1,2]); B22 = transpose(B[2,2]); B23 = transpose(B[3,2]);
B31 = transpose(B[1,3]); B32 = transpose(B[2,3]); B33 = transpose(B[3,3]);
elseif tB == 'C'
B11 = conj(B[1,1]); B12 = conj(B[2,1]); B13 = conj(B[3,1]);
B21 = conj(B[1,2]); B22 = conj(B[2,2]); B23 = conj(B[3,2]);
B31 = conj(B[1,3]); B32 = conj(B[2,3]); B33 = conj(B[3,3]);
B11 = ctranspose(B[1,1]); B12 = ctranspose(B[2,1]); B13 = ctranspose(B[3,1]);
B21 = ctranspose(B[1,2]); B22 = ctranspose(B[2,2]); B23 = ctranspose(B[3,2]);
B31 = ctranspose(B[1,3]); B32 = ctranspose(B[2,3]); B33 = ctranspose(B[3,3]);
else
B11 = B[1,1]; B12 = B[1,2]; B13 = B[1,3];
B21 = B[2,1]; B22 = B[2,2]; B23 = B[2,3];
Expand Down
5 changes: 5 additions & 0 deletions base/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ mod1{T<:Real}(x::T, y::T) = y-mod(y-x,y)
rem1{T<:Real}(x::T, y::T) = rem(x-1,y)+1
fld1{T<:Real}(x::T, y::T) = fld(x-1,y)+1

# transpose
transpose(x) = x
ctranspose(x) = conj(transpose(x))
conj(x) = x

# transposed multiply
Ac_mul_B (a,b) = ctranspose(a)*b
A_mul_Bc (a,b) = a*ctranspose(b)
Expand Down

2 comments on commit ab795c8

@JeffBezanson
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just realized we should have a test case for this.

@andreasnoack
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done in 987eb73. Right now they only check that A.'A is symmetric and A'A is Hermitian because the generic matvecmul can't handle x'A'Ax right now.

Please sign in to comment.