Skip to content

Commit

Permalink
Add a GPU version of copytrito! (#504)
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison authored Dec 11, 2023
1 parent 6becb4f commit effeef9
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 8 deletions.
45 changes: 37 additions & 8 deletions src/host/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,16 @@ function Base.copyto!(A::Array{T,N}, B::Transpose{T, <:AbstractGPUArray{T,N}}) w
copyto!(A, Transpose(Array(parent(B))))
end


## copy upper triangle to lower and vice versa

function LinearAlgebra.copytri!(A::AbstractGPUMatrix{T}, uplo::AbstractChar, conjugate::Bool=false) where T
function LinearAlgebra.copytri!(A::AbstractGPUMatrix, uplo::AbstractChar, conjugate::Bool=false)
n = LinearAlgebra.checksquare(A)
if uplo == 'U' && conjugate
gpu_call(A) do ctx, _A
I = @cartesianidx _A
i, j = Tuple(I)
if j > i
_A[j,i] = conj(_A[i,j])
@inbounds _A[j,i] = conj(_A[i,j])
end
return
end
Expand All @@ -99,7 +98,7 @@ function LinearAlgebra.copytri!(A::AbstractGPUMatrix{T}, uplo::AbstractChar, con
I = @cartesianidx _A
i, j = Tuple(I)
if j > i
_A[j,i] = _A[i,j]
@inbounds _A[j,i] = _A[i,j]
end
return
end
Expand All @@ -108,7 +107,7 @@ function LinearAlgebra.copytri!(A::AbstractGPUMatrix{T}, uplo::AbstractChar, con
I = @cartesianidx _A
i, j = Tuple(I)
if j > i
_A[i,j] = conj(_A[j,i])
@inbounds _A[i,j] = conj(_A[j,i])
end
return
end
Expand All @@ -117,7 +116,7 @@ function LinearAlgebra.copytri!(A::AbstractGPUMatrix{T}, uplo::AbstractChar, con
I = @cartesianidx _A
i, j = Tuple(I)
if j > i
_A[i,j] = _A[j,i]
@inbounds _A[i,j] = _A[j,i]
end
return
end
Expand All @@ -127,6 +126,36 @@ function LinearAlgebra.copytri!(A::AbstractGPUMatrix{T}, uplo::AbstractChar, con
A
end

## copy a triangular part of a matrix to another matrix

if isdefined(LinearAlgebra, :copytrito!)
    # This method extends `LinearAlgebra.copytrito!`, so it is only defined when the
    # loaded LinearAlgebra actually provides that function.
    #
    # copytrito!(B, A, uplo) copies one triangle of `A`, diagonal included, into the
    # same positions of `B` and returns `B`:
    #   * uplo == 'U'  -> entries with j >= i
    #   * otherwise    -> entries with j <= i
    # Positions of `B` outside the selected triangle are not written here.
    #
    # Throws `DimensionMismatch` when `B` has fewer rows or fewer columns than `A`.
    function LinearAlgebra.copytrito!(B::AbstractGPUMatrix, A::AbstractGPUMatrix, uplo::AbstractChar)
        # Validate `uplo` up front (presumably rejects anything other than 'U'/'L',
        # which is why the `else` branch below is treated as the 'L' case).
        LinearAlgebra.BLAS.chkuplo(uplo)
        m, n = size(A)
        m1, n1 = size(B)
        if m1 < m || n1 < n
            throw(DimensionMismatch(
                "B of size ($m1,$n1) should have at least the same number of rows and columns than A of size ($m,$n)"))
        end

        # An empty `A` has no entries to copy, so skip the kernel launch entirely.
        # NOTE(review): `gpu_call` is understood to require a positive number of
        # elements for its default launch configuration — confirm against GPUArrays.
        isempty(A) && return B

        if uplo == 'U'
            gpu_call(A, B) do ctx, _A, _B
                # Indices are derived from `_A`, so only positions that exist in `A`
                # are visited even when `B` is larger.
                I = @cartesianidx _A
                i, j = Tuple(I)
                if j >= i
                    @inbounds _B[i,j] = _A[i,j]
                end
                return
            end
        else # uplo == 'L'
            gpu_call(A, B) do ctx, _A, _B
                I = @cartesianidx _A
                i, j = Tuple(I)
                if j <= i
                    @inbounds _B[i,j] = _A[i,j]
                end
                return
            end
        end
        return B
    end
end

## triangular

Expand All @@ -146,7 +175,7 @@ function LinearAlgebra.tril!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
I = @cartesianidx _A
i, j = Tuple(I)
if i < j - _d
_A[i, j] = 0
@inbounds _A[i, j] = zero(T)
end
return
end
Expand All @@ -158,7 +187,7 @@ function LinearAlgebra.triu!(A::AbstractGPUMatrix{T}, d::Integer = 0) where T
I = @cartesianidx _A
i, j = Tuple(I)
if j < i + _d
_A[i, j] = 0
@inbounds _A[i, j] = zero(T)
end
return
end
Expand Down
11 changes: 11 additions & 0 deletions test/testsuite/linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,17 @@
end
end

if isdefined(LinearAlgebra, :copytrito!)
    # Only exercised when the loaded LinearAlgebra provides `copytrito!`.
    @testset "copytrito!" begin
        @testset for T in eltypes, uplo in ('L', 'U')
            dim = 16
            # Random source and zero-filled destination of the same square size.
            src = rand(T, dim, dim)
            dst = zeros(T, dim, dim)
            # Argument order mirrors copytrito!(B, A, uplo): destination first.
            @test compare(copytrito!, AT, dst, src, uplo)
        end
    end
end

@testset "copyto! for triangular" begin
for TR in (UpperTriangular, LowerTriangular)
@test compare(transpose!, AT, Array{Float32}(undef, 128, 32), rand(Float32, 32, 128))
Expand Down

0 comments on commit effeef9

Please sign in to comment.