From 4e9513b8a4e56629a236b58504d609b1775a8236 Mon Sep 17 00:00:00 2001 From: Alexis Montoison <35051714+amontoison@users.noreply.github.com> Date: Mon, 16 Dec 2024 02:23:32 -0600 Subject: [PATCH] [CUSOLVER] Interface gesv! and gels! (#2406) --- lib/cusolver/CUSOLVER.jl | 1 + lib/cusolver/base.jl | 60 +++++++++++ lib/cusolver/dense.jl | 108 +++++++++++++++++++ lib/cusolver/dense_generic.jl | 14 --- lib/cusolver/error.jl | 2 + lib/cusolver/helpers.jl | 148 ++++++++++++++++++++++++++ lib/cusolver/multigpu.jl | 25 ----- lib/cusolver/sparse_factorizations.jl | 28 +---- test/libraries/cusolver/dense.jl | 26 +++++ 9 files changed, 347 insertions(+), 65 deletions(-) create mode 100644 lib/cusolver/helpers.jl diff --git a/lib/cusolver/CUSOLVER.jl b/lib/cusolver/CUSOLVER.jl index f4b222dba7..3b742abfa3 100644 --- a/lib/cusolver/CUSOLVER.jl +++ b/lib/cusolver/CUSOLVER.jl @@ -33,6 +33,7 @@ include("libcusolverMg.jl") include("libcusolverRF.jl") # low-level wrappers +include("helpers.jl") include("error.jl") include("base.jl") include("sparse.jl") diff --git a/lib/cusolver/base.jl b/lib/cusolver/base.jl index 894ef97530..89bbcb41be 100644 --- a/lib/cusolver/base.jl +++ b/lib/cusolver/base.jl @@ -63,3 +63,63 @@ function Base.convert(::Type{cusolverDirectMode_t}, direct::Char) throw(ArgumentError("Unknown direction mode $direct.")) end end + +function Base.convert(::Type{cusolverIRSRefinement_t}, irs::String) + if irs == "NOT_SET" + CUSOLVER_IRS_REFINE_NOT_SET + elseif irs == "NONE" + CUSOLVER_IRS_REFINE_NONE + elseif irs == "CLASSICAL" + CUSOLVER_IRS_REFINE_CLASSICAL + elseif "CLASSICAL_GMRES" + CUSOLVER_IRS_REFINE_CLASSICAL_GMRES + elseif "GMRES" + CUSOLVER_IRS_REFINE_GMRES + elseif "GMRES_GMRES" + CUSOLVER_IRS_REFINE_GMRES_GMRES + elseif "GMRES_NOPCOND" + CUSOLVER_IRS_REFINE_GMRES_NOPCOND + else + throw(ArgumentError("Unknown iterative refinement solver $irs.")) + end +end + +function Base.convert(::Type{cusolverPrecType_t}, T::String) + if T == "R_16F" + return CUSOLVER_R_16F + elseif T == "R_16BF" + return CUSOLVER_R_16BF + elseif T == "R_TF32" + return CUSOLVER_R_TF32 + elseif T == "R_32F" + return CUSOLVER_R_32F + elseif T == "R_64F" + return CUSOLVER_R_64F + elseif T == "C_16F" + return CUSOLVER_C_16F + elseif T == "C_16BF" + return CUSOLVER_C_16BF + elseif T == "C_TF32" + return CUSOLVER_C_TF32 + elseif T == "C_32F" + return CUSOLVER_C_32F + elseif T == "C_64F" + return CUSOLVER_C_64F + else + throw(ArgumentError("cusolverPrecType_t equivalent for input type $T does not exist!")) + end +end + +function Base.convert(::Type{cusolverPrecType_t}, T::DataType) + if T === Float32 + return CUSOLVER_R_32F + elseif T === Float64 + return CUSOLVER_R_64F + elseif T === Complex{Float32} + return CUSOLVER_C_32F + elseif T === Complex{Float64} + return CUSOLVER_C_64F + else + throw(ArgumentError("cusolverPrecType_t equivalent for input type $T does not exist!")) + end +end diff --git a/lib/cusolver/dense.jl b/lib/cusolver/dense.jl index ae7a4a9721..fe6d88300f 100644 --- a/lib/cusolver/dense.jl +++ b/lib/cusolver/dense.jl @@ -884,6 +884,114 @@ for (fname, elty) in ((:cusolverDnSpotrfBatched, :Float32), end end +# gesv +function gesv!(X::CuVecOrMat{T}, A::CuMatrix{T}, B::CuVecOrMat{T}; fallback::Bool=true, + residual_history::Bool=false, irs_precision::String="AUTO", refinement_solver::String="CLASSICAL", + maxiters::Int=0, maxiters_inner::Int=0, tol::Float64=0.0, tol_inner=Float64=0.0) where T <: BlasFloat + + params = CuSolverIRSParameters() + info = CuSolverIRSInformation() + n = checksquare(A) + nrhs = size(B, 2) + lda = max(1, stride(A, 2)) + ldb = max(1, stride(B, 2)) + ldx = max(1, stride(X, 2)) + niters = Ref{Cint}() + dh = dense_handle() + + if irs_precision == "AUTO" + (T == Float32) && (irs_precision = "R_32F") + (T == Float64) && (irs_precision = "R_64F") + (T == ComplexF32) && (irs_precision = "C_32F") + (T == ComplexF64) && (irs_precision = "C_64F") + else + (T == Float32) && (irs_precision ∈ ("R_32F", "R_16F", "R_16BF", "R_TF32") || error("$irs_precision is not supported.")) + (T == Float64) && (irs_precision ∈ ("R_64F", "R_32F", "R_16F", "R_16BF", "R_TF32") || error("$irs_precision is not supported.")) + (T == ComplexF32) && (irs_precision ∈ ("C_32F", "C_16F", "C_16BF", "C_TF32") || error("$irs_precision is not supported.")) + (T == ComplexF64) && (irs_precision ∈ ("C_64F", "C_32F", "C_16F", "C_16BF", "C_TF32") || error("$irs_precision is not supported.")) + end + cusolverDnIRSParamsSetSolverMainPrecision(params, T) + cusolverDnIRSParamsSetSolverLowestPrecision(params, irs_precision) + cusolverDnIRSParamsSetRefinementSolver(params, refinement_solver) + (tol != 0.0) && cusolverDnIRSParamsSetTol(params, tol) + (tol_inner != 0.0) && cusolverDnIRSParamsSetTolInner(params, tol_inner) + (maxiters != 0) && cusolverDnIRSParamsSetMaxIters(params, maxiters) + (maxiters_inner != 0) && cusolverDnIRSParamsSetMaxItersInner(params, maxiters_inner) + fallback ? cusolverDnIRSParamsEnableFallback(params) : cusolverDnIRSParamsDisableFallback(params) + residual_history && cusolverDnIRSInfosRequestResidual(info) + + function bufferSize() + buffer_size = Ref{Csize_t}(0) + cusolverDnIRSXgesv_bufferSize(dh, params, n, nrhs, buffer_size) + return buffer_size[] + end + + with_workspace(dh.workspace_gpu, bufferSize) do buffer + cusolverDnIRSXgesv(dh, params, info, n, nrhs, A, lda, B, ldb, + X, ldx, buffer, sizeof(buffer), niters, dh.info) + end + + # Copy the solver flag and delete the device memory + flag = @allowscalar dh.info[1] + chklapackerror(flag |> BlasInt) + + return X, info +end + +# gels +function gels!(X::CuVecOrMat{T}, A::CuMatrix{T}, B::CuVecOrMat{T}; fallback::Bool=true, + residual_history::Bool=false, irs_precision::String="AUTO", refinement_solver::String="CLASSICAL", + maxiters::Int=0, maxiters_inner::Int=0, tol::Float64=0.0, tol_inner=Float64=0.0) where T <: BlasFloat + + params = CuSolverIRSParameters() + info = CuSolverIRSInformation() + m,n = size(A) + nrhs = size(B, 2) + lda = max(1, stride(A, 2)) + ldb = max(1, stride(B, 2)) + ldx = max(1, stride(X, 2)) + niters = Ref{Cint}() + dh = dense_handle() + + if irs_precision == "AUTO" + (T == Float32) && (irs_precision = "R_32F") + (T == Float64) && (irs_precision = "R_64F") + (T == ComplexF32) && (irs_precision = "C_32F") + (T == ComplexF64) && (irs_precision = "C_64F") + else + (T == Float32) && (irs_precision ∈ ("R_32F", "R_16F", "R_16BF", "R_TF32") || error("$irs_precision is not supported.")) + (T == Float64) && (irs_precision ∈ ("R_64F", "R_32F", "R_16F", "R_16BF", "R_TF32") || error("$irs_precision is not supported.")) + (T == ComplexF32) && (irs_precision ∈ ("C_32F", "C_16F", "C_16BF", "C_TF32") || error("$irs_precision is not supported.")) + (T == ComplexF64) && (irs_precision ∈ ("C_64F", "C_32F", "C_16F", "C_16BF", "C_TF32") || error("$irs_precision is not supported.")) + end + cusolverDnIRSParamsSetSolverMainPrecision(params, T) + cusolverDnIRSParamsSetSolverLowestPrecision(params, irs_precision) + cusolverDnIRSParamsSetRefinementSolver(params, refinement_solver) + (tol != 0.0) && cusolverDnIRSParamsSetTol(params, tol) + (tol_inner != 0.0) && cusolverDnIRSParamsSetTolInner(params, tol_inner) + (maxiters != 0) && cusolverDnIRSParamsSetMaxIters(params, maxiters) + (maxiters_inner != 0) && cusolverDnIRSParamsSetMaxItersInner(params, maxiters_inner) + fallback ? cusolverDnIRSParamsEnableFallback(params) : cusolverDnIRSParamsDisableFallback(params) + residual_history && cusolverDnIRSInfosRequestResidual(info) + + function bufferSize() + buffer_size = Ref{Csize_t}(0) + cusolverDnIRSXgels_bufferSize(dh, params, m, n, nrhs, buffer_size) + return buffer_size[] + end + + with_workspace(dh.workspace_gpu, bufferSize) do buffer + cusolverDnIRSXgels(dh, params, info, m, n, nrhs, A, lda, B, ldb, + X, ldx, buffer, sizeof(buffer), niters, dh.info) + end + + # Copy the solver flag and delete the device memory + flag = @allowscalar dh.info[1] + chklapackerror(flag |> BlasInt) + + return X, info +end + # LAPACK for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64) @eval begin diff --git a/lib/cusolver/dense_generic.jl b/lib/cusolver/dense_generic.jl index 93997424b5..7216169a9a 100644 --- a/lib/cusolver/dense_generic.jl +++ b/lib/cusolver/dense_generic.jl @@ -1,17 +1,3 @@ -mutable struct CuSolverParameters - parameters::cusolverDnParams_t - - function CuSolverParameters() - parameters_ref = Ref{cusolverDnParams_t}() - cusolverDnCreateParams(parameters_ref) - obj = new(parameters_ref[]) - finalizer(cusolverDnDestroyParams, obj) - obj - end -end - -Base.unsafe_convert(::Type{cusolverDnParams_t}, params::CuSolverParameters) = params.parameters - # Xpotrf function Xpotrf!(uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat} chkuplo(uplo) diff --git a/lib/cusolver/error.jl b/lib/cusolver/error.jl index d8ad0c6796..9edf271839 100644 --- a/lib/cusolver/error.jl +++ b/lib/cusolver/error.jl @@ -29,6 +29,8 @@ function description(err) "an internal operation failed" elseif err.code == CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED "the matrix type is not supported." + elseif err.code == CUSOLVER_STATUS_NOT_SUPPORTED + "the parameter combination is not supported." else "no description for this error" end diff --git a/lib/cusolver/helpers.jl b/lib/cusolver/helpers.jl new file mode 100644 index 0000000000..cbc80f32df --- /dev/null +++ b/lib/cusolver/helpers.jl @@ -0,0 +1,148 @@ +# cuSOLVER helper functions + +## SparseQRInfo + +mutable struct SparseQRInfo + info::csrqrInfo_t + + function SparseQRInfo() + info_ref = Ref{csrqrInfo_t}() + cusolverSpCreateCsrqrInfo(info_ref) + obj = new(info_ref[]) + finalizer(cusolverSpDestroyCsrqrInfo, obj) + obj + end +end + +Base.unsafe_convert(::Type{csrqrInfo_t}, info::SparseQRInfo) = info.info + + +## SparseCholeskyInfo + +mutable struct SparseCholeskyInfo + info::csrcholInfo_t + + function SparseCholeskyInfo() + info_ref = Ref{csrcholInfo_t}() + cusolverSpCreateCsrcholInfo(info_ref) + obj = new(info_ref[]) + finalizer(cusolverSpDestroyCsrcholInfo, obj) + obj + end +end + +Base.unsafe_convert(::Type{csrcholInfo_t}, info::SparseCholeskyInfo) = info.info + + +## CuSolverParameters + +mutable struct CuSolverParameters + parameters::cusolverDnParams_t + + function CuSolverParameters() + parameters_ref = Ref{cusolverDnParams_t}() + cusolverDnCreateParams(parameters_ref) + obj = new(parameters_ref[]) + finalizer(cusolverDnDestroyParams, obj) + obj + end +end + +Base.unsafe_convert(::Type{cusolverDnParams_t}, params::CuSolverParameters) = params.parameters + + +## CuSolverIRSParameters + +mutable struct CuSolverIRSParameters + parameters::cusolverDnIRSParams_t + + function CuSolverIRSParameters() + parameters_ref = Ref{cusolverDnIRSParams_t}() + cusolverDnIRSParamsCreate(parameters_ref) + obj = new(parameters_ref[]) + finalizer(cusolverDnIRSParamsDestroy, obj) + obj + end +end + +Base.unsafe_convert(::Type{cusolverDnIRSParams_t}, params::CuSolverIRSParameters) = params.parameters + +function get_info(params::CuSolverIRSParameters, field::Symbol) + if field == :maxiters + maxiters = Ref{Cint}() + cusolverDnIRSParamsGetMaxIters(params, maxiters) + return maxiters[] + else + error("The information $field is incorrect.") + end +end + + +## CuSolverIRSInformation + +mutable struct CuSolverIRSInformation + information::cusolverDnIRSInfos_t + + function CuSolverIRSInformation() + info_ref = Ref{cusolverDnIRSInfos_t}() + cusolverDnIRSInfosCreate(info_ref) + obj = new(info_ref[]) + finalizer(cusolverDnIRSInfosDestroy, obj) + obj + end +end + +Base.unsafe_convert(::Type{cusolverDnIRSInfos_t}, info::CuSolverIRSInformation) = info.information + +function get_info(info::CuSolverIRSInformation, field::Symbol) + if field == :niters + niters = Ref{Cint}() + cusolverDnIRSInfosGetNiters(info, niters) + return niters[] + elseif field == :outer_niters + outer_niters = Ref{Cint}() + cusolverDnIRSInfosGetOuterNiters(info, outer_niters) + return outer_niters[] + # elseif field == :residual_history + # residual_history = Ref{Ptr{Cvoid} + # cusolverDnIRSInfosGetResidualHistory(info, residual_history) + # return residual_history[] + elseif field == :maxiters + maxiters = Ref{Cint}() + cusolverDnIRSInfosGetMaxIters(info, maxiters) + return maxiters[] + else + error("The information $field is incorrect.") + end +end + + +## MatrixDescriptor + +mutable struct MatrixDescriptor + desc::cudaLibMgMatrixDesc_t + + function MatrixDescriptor(a, grid; rowblocks = size(a, 1), colblocks = size(a, 2), elta=eltype(a) ) + desc = Ref{cudaLibMgMatrixDesc_t}() + cusolverMgCreateMatrixDesc(desc, size(a, 1), size(a, 2), rowblocks, colblocks, elta, grid) + return new(desc[]) + end +end + +Base.unsafe_convert(::Type{cudaLibMgMatrixDesc_t}, obj::MatrixDescriptor) = obj.desc + + +## DeviceGrid + +mutable struct DeviceGrid + desc::cudaLibMgGrid_t + + function DeviceGrid(num_row_devs, num_col_devs, deviceIds, mapping) + @assert num_row_devs == 1 "Only 1-D column block cyclic is supported, so numRowDevices must be equal to 1." + desc = Ref{cudaLibMgGrid_t}() + cusolverMgCreateDeviceGrid(desc, num_row_devs, num_col_devs, deviceIds, mapping) + return new(desc[]) + end +end + +Base.unsafe_convert(::Type{cudaLibMgGrid_t}, obj::DeviceGrid) = obj.desc diff --git a/lib/cusolver/multigpu.jl b/lib/cusolver/multigpu.jl index 71a8b68beb..15b13154c8 100644 --- a/lib/cusolver/multigpu.jl +++ b/lib/cusolver/multigpu.jl @@ -7,31 +7,6 @@ # NOTE: in the cublasMg preview, which also relies on this functionality, a separate library # called 'cudalibmg' is introduced. factor this out when we actually ship that. -mutable struct MatrixDescriptor - desc::cudaLibMgMatrixDesc_t - - function MatrixDescriptor(a, grid; rowblocks = size(a, 1), colblocks = size(a, 2), elta=eltype(a) ) - desc = Ref{cudaLibMgMatrixDesc_t}() - cusolverMgCreateMatrixDesc(desc, size(a, 1), size(a, 2), rowblocks, colblocks, elta, grid) - return new(desc[]) - end -end - -Base.unsafe_convert(::Type{cudaLibMgMatrixDesc_t}, obj::MatrixDescriptor) = obj.desc - -mutable struct DeviceGrid - desc::cudaLibMgGrid_t - - function DeviceGrid(num_row_devs, num_col_devs, deviceIds, mapping) - @assert num_row_devs == 1 "Only 1-D column block cyclic is supported, so numRowDevices must be equal to 1." - desc = Ref{cudaLibMgGrid_t}() - cusolverMgCreateDeviceGrid(desc, num_row_devs, num_col_devs, deviceIds, mapping) - return new(desc[]) - end -end - -Base.unsafe_convert(::Type{cudaLibMgGrid_t}, obj::DeviceGrid) = obj.desc - function allocateBuffers(n_row_devs, n_col_devs, mat::Matrix) mat_row_block_size = div(size(mat, 1), n_row_devs) mat_col_block_size = div(size(mat, 2), n_col_devs) diff --git a/lib/cusolver/sparse_factorizations.jl b/lib/cusolver/sparse_factorizations.jl index 1d3cf142e3..4c504ca8f1 100644 --- a/lib/cusolver/sparse_factorizations.jl +++ b/lib/cusolver/sparse_factorizations.jl @@ -1,16 +1,4 @@ -mutable struct SparseQRInfo - info::csrqrInfo_t - - function SparseQRInfo() - info_ref = Ref{csrqrInfo_t}() - cusolverSpCreateCsrqrInfo(info_ref) - obj = new(info_ref[]) - finalizer(cusolverSpDestroyCsrqrInfo, obj) - obj - end -end - -Base.unsafe_convert(::Type{csrqrInfo_t}, info::SparseQRInfo) = info.info +## ----- Sparse QR ----- mutable struct SparseQR{T <: BlasFloat} <: Factorization{T} n::Cint @@ -155,19 +143,7 @@ for (bname, iname, fname, sname, pname, elty, relty) in end end -mutable struct SparseCholeskyInfo - info::csrcholInfo_t - - function SparseCholeskyInfo() - info_ref = Ref{csrcholInfo_t}() - cusolverSpCreateCsrcholInfo(info_ref) - obj = new(info_ref[]) - finalizer(cusolverSpDestroyCsrcholInfo, obj) - obj - end -end - -Base.unsafe_convert(::Type{csrcholInfo_t}, info::SparseCholeskyInfo) = info.info +## ----- Sparse Cholesky ----- mutable struct SparseCholesky{T <: BlasFloat} <: Factorization{T} n::Cint diff --git a/test/libraries/cusolver/dense.jl b/test/libraries/cusolver/dense.jl index 8b7590bef4..ce78364247 100644 --- a/test/libraries/cusolver/dense.jl +++ b/test/libraries/cusolver/dense.jl @@ -5,10 +5,36 @@ using LinearAlgebra: BlasInt m = 15 n = 10 +p = 5 l = 13 k = 1 @testset "elty = $elty" for elty in [Float32, Float64, ComplexF32, ComplexF64] + @testset "gesv!" begin + A = rand(elty, n, n) + X = zeros(elty, n, p) + B = rand(elty, n, p) + dA = CuArray(A) + dX = CuArray(X) + dB = CuArray(B) + CUSOLVER.gesv!(dX, dA, dB) + tol = real(elty) |> eps |> sqrt + dR = dB - dA * dX + @test norm(dR) <= tol + end + + @testset "gels!" begin + A = rand(elty, m, n) + X = zeros(elty, n, p) + B = A * rand(elty, n, p) # ensure that AX = B is consistent + dA = CuArray(A) + dX = CuArray(X) + dB = CuArray(B) + CUSOLVER.gels!(dX, dA, dB) + tol = real(elty) |> eps |> sqrt + dR = dB - dA * dX + end + @testset "geqrf! -- orgqr!" begin A = rand(elty, m, n) dA = CuArray(A)