diff --git a/src/host/linalg.jl b/src/host/linalg.jl index a2a99019..f0bb2436 100644 --- a/src/host/linalg.jl +++ b/src/host/linalg.jl @@ -372,6 +372,127 @@ function generic_matmatmul!(C::AbstractArray{R}, A::AbstractArray{T}, B::Abstrac C end +function generic_trimatmul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Function, A::AbstractGPUMatrix{T}, B::AbstractGPUVecOrMat{S}) where {T,S,R} + if size(A,2) != size(B,1) + throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))")) + end + if size(C,1) != size(A,1) || size(C,2) != size(B,2) + throw(DimensionMismatch("result C has dimensions $(size(C)), needs $((size(A,1),size(B,2)))")) + end + if isempty(A) || isempty(B) + return fill!(C, zero(R)) + end + + upper = uploc == 'U' + unit = isunitc == 'U' + + function trimatmul(ctx, C, A, B) + idx = @linearidx C + assume.(size(C) .> 0) + i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1 + l, m, n = size(A, 1), size(B, 1), size(B, 2) + + @inbounds if i <= l && j <= n + z2 = zero(A[i,1] * B[1,j] + A[i,1] * B[1,j]) + Cij = convert(promote_type(R, typeof(z2)), z2) + Cij += (unit ? one(Cij) : A[i,i]) * B[i,j] + for k in (upper ? (i + 1) : 1):(upper ? m : (i - 1)) + Cij += A[i,k] * B[k,j] + end + C[i,j] += Cij + end + + return + end + + function trimatmul_t(ctx, C, A, B) + idx = @linearidx C + assume.(size(C) .> 0) + i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1 + l, m, n = size(A, 1), size(B, 1), size(B, 2) + + @inbounds if i <= l && j <= n + z2 = zero(A[i,1] * B[1,j] + A[i,1] * B[1,j]) + Cij = convert(promote_type(R, typeof(z2)), z2) + Cij += (unit ? one(Cij) : A[i,i]) * B[i,j] + for k in (upper ? (i + 1) : 1):(upper ? m : (i - 1)) + Ctmp += A[k,i] * B[k,j] + end + C[i,j] += Cij + end + + return + end + + function trimatmul_c(ctx, C, A, B) + idx = @linearidx C + assume.(size(C) .> 0) + i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1 + l, m, n = size(A, 1), size(B, 1), size(B, 2) + + @inbounds if i <= l && j <= n + z2 = zero(A[i,1] * B[1,j] + A[i,1] * B[1,j]) + Cij = convert(promote_type(R, typeof(z2)), z2) + Cij += (unit ? one(Cij) : conj(A[i,i])) * B[i,j] + for k in (upper ? (i + 1) : 1):(upper ? m : (i - 1)) + Cij += conj(A[k,i]) * B[k,j] + end + C[i,j] += Cij + end + + return + end + + if tfun === identity + gpu_call(trimatmul, C, A, B; name="trimatmul") + elseif tfun == transpose + gpu_call(trimatmul_t, C, A, B; name="trimatmul") + elseif tfun === adjoint + gpu_call(trimatmul_c, C, A, B; name="trimatmul") + else + error("Not supported") + end + + C +end + +function generic_mattrimul!(C::AbstractGPUVecOrMat{R}, uploc, isunitc, tfun::Function, A::AbstractGPUMatrix{T}, B::AbstractGPUVecOrMat{S}) where {T,S,R} + if size(A,2) != size(B,1) + throw(DimensionMismatch("matrix A has dimensions $(size(A)), matrix B has dimensions $(size(B))")) + end + if size(C,1) != size(A,1) || size(C,2) != size(B,2) + throw(DimensionMismatch("result C has dimensions $(size(C)), needs $((size(A,1),size(B,2)))")) + end + if isempty(A) || isempty(B) + return fill!(C, zero(R)) + end + + upper = uploc == 'U' + unit = isunitc == 'U' + + # tA = tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C' + + gpu_call(C, A, B; name="mattrimul") do ctx, C, A, B + idx = @linearidx C + assume.(size(C) .> 0) + i, j = @inbounds Tuple(CartesianIndices(C)[idx])..., 1 + + @inbounds if i <= size(A,1) && j <= size(B,2) + z2 = zero(A[i, 1]*B[1, j] + A[i, 1]*B[1, j]) + Ctmp = convert(promote_type(R, typeof(z2)), z2) + Ctmp += A[i, j] * (unit ? one(Ctmp) : B[j, j]) + for k in (upper ? 1 : (j + 1)):(upper ? (j - 1) : (size(A,2))) + Ctmp += A[i, k]*B[k, j] + end + C[i,j] += Ctmp + end + + return + end + + C +end + function LinearAlgebra.generic_matvecmul!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix, B::AbstractGPUVector, _add::MulAddMul = MulAddMul()) generic_matmatmul!(C, wrap(A, tA), B, _add.alpha, _add.beta) end @@ -380,6 +501,18 @@ function LinearAlgebra.generic_matmatmul!(C::AbstractGPUVecOrMat, tA, tB, A::Abs generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), _add.alpha, _add.beta) end +if VERSION >= v"1.10-" + +function LinearAlgebra.generic_trimatmul!(C::AbstractGPUVecOrMat, uploc, isunitc, tfun::Function, A::AbstractGPUMatrix, B::AbstractGPUVecOrMat) + generic_trimatmul!(C, uploc, isunitc, tfun, A, B) +end + +function LinearAlgebra.generic_mattrimul!(C::AbstractGPUMatrix, uploc, isunitc, tfun::Function, A::AbstractGPUMatrix, B::AbstractGPUMatrix) + generic_mattrimul!(C, uploc, isunitc, tfun, A, B) +end + +end + if VERSION < v"1.10.0-DEV.1365" # catch other functions that are called by LinearAlgebra's mul! function LinearAlgebra.gemv!(C::AbstractGPUVector, tA::AbstractChar, A::AbstractGPUMatrix, B::AbstractGPUVector, a::Number, b::Number) diff --git a/test/testsuite/linalg.jl b/test/testsuite/linalg.jl index da858608..24549a48 100644 --- a/test/testsuite/linalg.jl +++ b/test/testsuite/linalg.jl @@ -132,6 +132,97 @@ @test istriu(A) == istriu(B) end end + + if VERSION >= v"1.10-" + @testset "trimatmul" begin + n = 128 + b = AT(rand(Float32, n)) + B = AT(rand(Float32, n, n)) + + At = UpperTriangular(AT(rand(Float32, n,n))) + A = AT(At) + Ct = AT(zeros(Float32, n)) + C = zeros(Float32, n) + + LinearAlgebra.generic_trimatmul!(Ct, At, b) + mul!(C, A, b) + @test Ct ≈ C + + Ct = AT(zeros(Float32, n, n)) + C = zeros(Float32, n, n) + + LinearAlgebra.generic_trimatmul!(Ct, At, B) + mul!(C, A, B) + @test Ct ≈ C + + At = UnitUpperTriangular(AT(rand(Float32, n,n))) + A = AT(At) + Ct = AT(zeros(Float32, n)) + C = zeros(Float32, n) + + LinearAlgebra.generic_trimatmul!(Ct, At, b) + mul!(C, A, b) + @test Ct ≈ C + + Ct = AT(zeros(Float32, n, n)) + C = zeros(Float32, n, n) + + LinearAlgebra.generic_trimatmul!(Ct, At, B) + mul!(C, A, B) + @test Ct ≈ C + + + At = LowerTriangular(AT(rand(Float32, n,n))) + A = AT(At) + Ct = AT(zeros(Float32, n)) + C = zeros(Float32, n) + + LinearAlgebra.generic_trimatmul!(Ct, At, b) + mul!(C, A, b) + @test Ct ≈ C + + Ct = AT(zeros(Float32, n, n)) + C = zeros(Float32, n, n) + + LinearAlgebra.generic_trimatmul!(Ct, At, B) + mul!(C, A, B) + @test Ct ≈ C + + At = UnitLowerTriangular(AT(rand(Float32, n,n))) + A = AT(At) + Ct = AT(zeros(Float32, n)) + C = zeros(Float32, n) + + LinearAlgebra.generic_trimatmul!(Ct, At, b) + mul!(C, A, b) + @test Ct ≈ C + + Ct = AT(zeros(Float32, n, n)) + C = zeros(Float32, n, n) + + LinearAlgebra.generic_trimatmul!(Ct, At, B) + mul!(C, A, B) + @test Ct ≈ C + + At = UnitLowerTriangular(AT(rand(ComplexF32, n,n))) + A = AT(At) + b = AT(rand(ComplexF32, n)) + B = AT(rand(ComplexF32, n, n)) + Ct = AT(zeros(ComplexF32, n)) + C = zeros(ComplexF32, n) + + LinearAlgebra.generic_trimatmul!(Ct, At, b) + mul!(C, A, b) + @test Ct ≈ C + + Ct = AT(zeros(ComplexF32, n, n)) + C = zeros(ComplexF32, n, n) + + LinearAlgebra.generic_trimatmul!(Ct, At, B) + mul!(C, A, B) + @test Ct ≈ C + end + end end @testset "diagonal" begin