Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CUSOLVER] Interface gesv! and gels! #2406

Merged
merged 1 commit into from
Dec 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions lib/cusolver/CUSOLVER.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ include("libcusolverMg.jl")
include("libcusolverRF.jl")

# low-level wrappers
include("helpers.jl")
include("error.jl")
include("base.jl")
include("sparse.jl")
Expand Down
60 changes: 60 additions & 0 deletions lib/cusolver/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,63 @@ function Base.convert(::Type{cusolverDirectMode_t}, direct::Char)
throw(ArgumentError("Unknown direction mode $direct."))
end
end

"""
    convert(::Type{cusolverIRSRefinement_t}, irs::String)

Translate the name of an iterative refinement solver into the matching
`cusolverIRSRefinement_t` enum value.

Accepted names are `"NOT_SET"`, `"NONE"`, `"CLASSICAL"`, `"CLASSICAL_GMRES"`,
`"GMRES"`, `"GMRES_GMRES"` and `"GMRES_NOPCOND"`.

# Throws
- `ArgumentError` if `irs` is not one of the names above.
"""
function Base.convert(::Type{cusolverIRSRefinement_t}, irs::String)
    # Every branch must compare against `irs`: a bare string literal is not a
    # `Bool`, so using it as a condition raises a `TypeError` at runtime.
    if irs == "NOT_SET"
        return CUSOLVER_IRS_REFINE_NOT_SET
    elseif irs == "NONE"
        return CUSOLVER_IRS_REFINE_NONE
    elseif irs == "CLASSICAL"
        return CUSOLVER_IRS_REFINE_CLASSICAL
    elseif irs == "CLASSICAL_GMRES"
        return CUSOLVER_IRS_REFINE_CLASSICAL_GMRES
    elseif irs == "GMRES"
        return CUSOLVER_IRS_REFINE_GMRES
    elseif irs == "GMRES_GMRES"
        return CUSOLVER_IRS_REFINE_GMRES_GMRES
    elseif irs == "GMRES_NOPCOND"
        return CUSOLVER_IRS_REFINE_GMRES_NOPCOND
    else
        throw(ArgumentError("Unknown iterative refinement solver $irs."))
    end
end

"""
    convert(::Type{cusolverPrecType_t}, T::String)

Translate a precision name such as `"R_32F"` or `"C_16BF"` into the matching
`cusolverPrecType_t` enum value (`CUSOLVER_<name>`).

# Throws
- `ArgumentError` if `T` is not one of the recognised precision names.
"""
function Base.convert(::Type{cusolverPrecType_t}, T::String)
    # Names prefixed with `R_`.
    T == "R_16F" && return CUSOLVER_R_16F
    T == "R_16BF" && return CUSOLVER_R_16BF
    T == "R_TF32" && return CUSOLVER_R_TF32
    T == "R_32F" && return CUSOLVER_R_32F
    T == "R_64F" && return CUSOLVER_R_64F

    # Names prefixed with `C_`.
    T == "C_16F" && return CUSOLVER_C_16F
    T == "C_16BF" && return CUSOLVER_C_16BF
    T == "C_TF32" && return CUSOLVER_C_TF32
    T == "C_32F" && return CUSOLVER_C_32F
    T == "C_64F" && return CUSOLVER_C_64F

    throw(ArgumentError("cusolverPrecType_t equivalent for input type $T does not exist!"))
end

"""
    convert(::Type{cusolverPrecType_t}, T::DataType)

Translate a Julia element type into the matching `cusolverPrecType_t` enum
value. Only `Float32`, `Float64`, `Complex{Float32}` and `Complex{Float64}`
are mapped.

# Throws
- `ArgumentError` for any other type.
"""
function Base.convert(::Type{cusolverPrecType_t}, T::DataType)
    T === Float32 && return CUSOLVER_R_32F
    T === Float64 && return CUSOLVER_R_64F
    T === Complex{Float32} && return CUSOLVER_C_32F
    T === Complex{Float64} && return CUSOLVER_C_64F
    throw(ArgumentError("cusolverPrecType_t equivalent for input type $T does not exist!"))
end
108 changes: 108 additions & 0 deletions lib/cusolver/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -884,6 +884,114 @@ for (fname, elty) in ((:cusolverDnSpotrfBatched, :Float32),
end
end

# gesv
"""
    gesv!(X, A, B; fallback=true, residual_history=false, irs_precision="AUTO",
          refinement_solver="CLASSICAL", maxiters=0, maxiters_inner=0,
          tol=0.0, tol_inner=0.0)

Call the cuSOLVER iterative refinement solver `cusolverDnIRSXgesv` on the square
matrix `A` and right-hand side(s) `B`, writing the result into `X`.

# Keywords
- `fallback`: enable (`true`) or disable (`false`) the solver fallback.
- `residual_history`: when `true`, request the residual history on the returned
  information object.
- `irs_precision`: lowest precision used by the solver. `"AUTO"` selects the
  precision matching `T`; otherwise it must be one of the precisions accepted
  for `T` (see the checks in the body).
- `refinement_solver`: name of the refinement solver, e.g. `"CLASSICAL"`.
- `maxiters`, `maxiters_inner`, `tol`, `tol_inner`: forwarded to the matching
  `cusolverDnIRSParamsSet*` function only when non-zero; a value of zero leaves
  the parameter object untouched. `tol_inner` accepts any `Real` so that callers
  of the previously untyped keyword keep working.

# Returns
The tuple `(X, info)` where `info` is the `CuSolverIRSInformation` object passed
to the solver (query it with `get_info`).

# Throws
- `ErrorException` if `irs_precision` is not supported for element type `T`.
- Whatever `chklapackerror` raises for a non-zero solver flag.
"""
function gesv!(X::CuVecOrMat{T}, A::CuMatrix{T}, B::CuVecOrMat{T}; fallback::Bool=true,
               residual_history::Bool=false, irs_precision::String="AUTO",
               refinement_solver::String="CLASSICAL", maxiters::Int=0, maxiters_inner::Int=0,
               tol::Float64=0.0, tol_inner::Real=0.0) where T <: BlasFloat

    params = CuSolverIRSParameters()
    info = CuSolverIRSInformation()
    n = checksquare(A)
    nrhs = size(B, 2)
    lda = max(1, stride(A, 2))
    ldb = max(1, stride(B, 2))
    ldx = max(1, stride(X, 2))
    niters = Ref{Cint}()
    dh = dense_handle()

    if irs_precision == "AUTO"
        # Use the working precision of the inputs as the lowest precision.
        (T == Float32) && (irs_precision = "R_32F")
        (T == Float64) && (irs_precision = "R_64F")
        (T == ComplexF32) && (irs_precision = "C_32F")
        (T == ComplexF64) && (irs_precision = "C_64F")
    else
        # Reject precisions that are not listed for the element type `T`.
        (T == Float32) && (irs_precision ∈ ("R_32F", "R_16F", "R_16BF", "R_TF32") || error("$irs_precision is not supported."))
        (T == Float64) && (irs_precision ∈ ("R_64F", "R_32F", "R_16F", "R_16BF", "R_TF32") || error("$irs_precision is not supported."))
        (T == ComplexF32) && (irs_precision ∈ ("C_32F", "C_16F", "C_16BF", "C_TF32") || error("$irs_precision is not supported."))
        (T == ComplexF64) && (irs_precision ∈ ("C_64F", "C_32F", "C_16F", "C_16BF", "C_TF32") || error("$irs_precision is not supported."))
    end

    # NOTE(review): `T` and the strings are presumably turned into cuSOLVER enums
    # by the `Base.convert` methods for `cusolverPrecType_t` and
    # `cusolverIRSRefinement_t` -- confirm against the low-level wrappers.
    cusolverDnIRSParamsSetSolverMainPrecision(params, T)
    cusolverDnIRSParamsSetSolverLowestPrecision(params, irs_precision)
    cusolverDnIRSParamsSetRefinementSolver(params, refinement_solver)

    # Zero means "not specified": skip the setter in that case.
    (tol != 0.0) && cusolverDnIRSParamsSetTol(params, tol)
    (tol_inner != 0.0) && cusolverDnIRSParamsSetTolInner(params, tol_inner)
    (maxiters != 0) && cusolverDnIRSParamsSetMaxIters(params, maxiters)
    (maxiters_inner != 0) && cusolverDnIRSParamsSetMaxItersInner(params, maxiters_inner)
    fallback ? cusolverDnIRSParamsEnableFallback(params) : cusolverDnIRSParamsDisableFallback(params)
    residual_history && cusolverDnIRSInfosRequestResidual(info)

    # Workspace size query used by `with_workspace`.
    function bufferSize()
        buffer_size = Ref{Csize_t}(0)
        cusolverDnIRSXgesv_bufferSize(dh, params, n, nrhs, buffer_size)
        return buffer_size[]
    end

    with_workspace(dh.workspace_gpu, bufferSize) do buffer
        cusolverDnIRSXgesv(dh, params, info, n, nrhs, A, lda, B, ldb,
                           X, ldx, buffer, sizeof(buffer), niters, dh.info)
    end

    # Read the solver flag back from `dh.info` and check it.
    flag = @allowscalar dh.info[1]
    chklapackerror(flag |> BlasInt)

    return X, info
end

# gels
"""
    gels!(X, A, B; fallback=true, residual_history=false, irs_precision="AUTO",
          refinement_solver="CLASSICAL", maxiters=0, maxiters_inner=0,
          tol=0.0, tol_inner=0.0)

Call the cuSOLVER iterative refinement solver `cusolverDnIRSXgels` on the
`m × n` matrix `A` and right-hand side(s) `B`, writing the result into `X`.

# Keywords
- `fallback`: enable (`true`) or disable (`false`) the solver fallback.
- `residual_history`: when `true`, request the residual history on the returned
  information object.
- `irs_precision`: lowest precision used by the solver. `"AUTO"` selects the
  precision matching `T`; otherwise it must be one of the precisions accepted
  for `T` (see the checks in the body).
- `refinement_solver`: name of the refinement solver, e.g. `"CLASSICAL"`.
- `maxiters`, `maxiters_inner`, `tol`, `tol_inner`: forwarded to the matching
  `cusolverDnIRSParamsSet*` function only when non-zero; a value of zero leaves
  the parameter object untouched. `tol_inner` accepts any `Real` so that callers
  of the previously untyped keyword keep working.

# Returns
The tuple `(X, info)` where `info` is the `CuSolverIRSInformation` object passed
to the solver (query it with `get_info`).

# Throws
- `ErrorException` if `irs_precision` is not supported for element type `T`.
- Whatever `chklapackerror` raises for a non-zero solver flag.
"""
function gels!(X::CuVecOrMat{T}, A::CuMatrix{T}, B::CuVecOrMat{T}; fallback::Bool=true,
               residual_history::Bool=false, irs_precision::String="AUTO",
               refinement_solver::String="CLASSICAL", maxiters::Int=0, maxiters_inner::Int=0,
               tol::Float64=0.0, tol_inner::Real=0.0) where T <: BlasFloat

    params = CuSolverIRSParameters()
    info = CuSolverIRSInformation()
    m, n = size(A)
    nrhs = size(B, 2)
    lda = max(1, stride(A, 2))
    ldb = max(1, stride(B, 2))
    ldx = max(1, stride(X, 2))
    niters = Ref{Cint}()
    dh = dense_handle()

    if irs_precision == "AUTO"
        # Use the working precision of the inputs as the lowest precision.
        (T == Float32) && (irs_precision = "R_32F")
        (T == Float64) && (irs_precision = "R_64F")
        (T == ComplexF32) && (irs_precision = "C_32F")
        (T == ComplexF64) && (irs_precision = "C_64F")
    else
        # Reject precisions that are not listed for the element type `T`.
        (T == Float32) && (irs_precision ∈ ("R_32F", "R_16F", "R_16BF", "R_TF32") || error("$irs_precision is not supported."))
        (T == Float64) && (irs_precision ∈ ("R_64F", "R_32F", "R_16F", "R_16BF", "R_TF32") || error("$irs_precision is not supported."))
        (T == ComplexF32) && (irs_precision ∈ ("C_32F", "C_16F", "C_16BF", "C_TF32") || error("$irs_precision is not supported."))
        (T == ComplexF64) && (irs_precision ∈ ("C_64F", "C_32F", "C_16F", "C_16BF", "C_TF32") || error("$irs_precision is not supported."))
    end

    # NOTE(review): `T` and the strings are presumably turned into cuSOLVER enums
    # by the `Base.convert` methods for `cusolverPrecType_t` and
    # `cusolverIRSRefinement_t` -- confirm against the low-level wrappers.
    cusolverDnIRSParamsSetSolverMainPrecision(params, T)
    cusolverDnIRSParamsSetSolverLowestPrecision(params, irs_precision)
    cusolverDnIRSParamsSetRefinementSolver(params, refinement_solver)

    # Zero means "not specified": skip the setter in that case.
    (tol != 0.0) && cusolverDnIRSParamsSetTol(params, tol)
    (tol_inner != 0.0) && cusolverDnIRSParamsSetTolInner(params, tol_inner)
    (maxiters != 0) && cusolverDnIRSParamsSetMaxIters(params, maxiters)
    (maxiters_inner != 0) && cusolverDnIRSParamsSetMaxItersInner(params, maxiters_inner)
    fallback ? cusolverDnIRSParamsEnableFallback(params) : cusolverDnIRSParamsDisableFallback(params)
    residual_history && cusolverDnIRSInfosRequestResidual(info)

    # Workspace size query used by `with_workspace`.
    function bufferSize()
        buffer_size = Ref{Csize_t}(0)
        cusolverDnIRSXgels_bufferSize(dh, params, m, n, nrhs, buffer_size)
        return buffer_size[]
    end

    with_workspace(dh.workspace_gpu, bufferSize) do buffer
        cusolverDnIRSXgels(dh, params, info, m, n, nrhs, A, lda, B, ldb,
                           X, ldx, buffer, sizeof(buffer), niters, dh.info)
    end

    # Read the solver flag back from `dh.info` and check it.
    flag = @allowscalar dh.info[1]
    chklapackerror(flag |> BlasInt)

    return X, info
end

# LAPACK
for elty in (:Float32, :Float64, :ComplexF32, :ComplexF64)
@eval begin
Expand Down
14 changes: 0 additions & 14 deletions lib/cusolver/dense_generic.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,3 @@
"""
    CuSolverParameters()

Mutable wrapper holding a `cusolverDnParams_t` handle obtained from
`cusolverDnCreateParams`. A finalizer calls `cusolverDnDestroyParams` on the
wrapper.
"""
mutable struct CuSolverParameters
    parameters::cusolverDnParams_t

    function CuSolverParameters()
        handle = Ref{cusolverDnParams_t}()
        cusolverDnCreateParams(handle)
        # `finalizer` returns the object it was registered on.
        return finalizer(cusolverDnDestroyParams, new(handle[]))
    end
end

# Expose the raw handle when the wrapper is passed where `cusolverDnParams_t` is expected.
function Base.unsafe_convert(::Type{cusolverDnParams_t}, params::CuSolverParameters)
    return params.parameters
end

# Xpotrf
function Xpotrf!(uplo::Char, A::StridedCuMatrix{T}) where {T <: BlasFloat}
chkuplo(uplo)
Expand Down
2 changes: 2 additions & 0 deletions lib/cusolver/error.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ function description(err)
"an internal operation failed"
elseif err.code == CUSOLVER_STATUS_MATRIX_TYPE_NOT_SUPPORTED
"the matrix type is not supported."
elseif err.code == CUSOLVER_STATUS_NOT_SUPPORTED
"the parameter combination is not supported."
else
"no description for this error"
end
Expand Down
148 changes: 148 additions & 0 deletions lib/cusolver/helpers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# cuSOLVER helper functions

## SparseQRInfo

"""
    SparseQRInfo()

Mutable wrapper holding a `csrqrInfo_t` handle obtained from
`cusolverSpCreateCsrqrInfo`. A finalizer calls `cusolverSpDestroyCsrqrInfo` on
the wrapper.
"""
mutable struct SparseQRInfo
    info::csrqrInfo_t

    function SparseQRInfo()
        handle = Ref{csrqrInfo_t}()
        cusolverSpCreateCsrqrInfo(handle)
        # `finalizer` returns the object it was registered on.
        return finalizer(cusolverSpDestroyCsrqrInfo, new(handle[]))
    end
end

# Expose the raw handle when the wrapper is passed where `csrqrInfo_t` is expected.
function Base.unsafe_convert(::Type{csrqrInfo_t}, info::SparseQRInfo)
    return info.info
end


## SparseCholeskyInfo

"""
    SparseCholeskyInfo()

Mutable wrapper holding a `csrcholInfo_t` handle obtained from
`cusolverSpCreateCsrcholInfo`. A finalizer calls `cusolverSpDestroyCsrcholInfo`
on the wrapper.
"""
mutable struct SparseCholeskyInfo
    info::csrcholInfo_t

    function SparseCholeskyInfo()
        handle = Ref{csrcholInfo_t}()
        cusolverSpCreateCsrcholInfo(handle)
        # `finalizer` returns the object it was registered on.
        return finalizer(cusolverSpDestroyCsrcholInfo, new(handle[]))
    end
end

# Expose the raw handle when the wrapper is passed where `csrcholInfo_t` is expected.
function Base.unsafe_convert(::Type{csrcholInfo_t}, info::SparseCholeskyInfo)
    return info.info
end


## CuSolverParameters

"""
    CuSolverParameters()

Mutable wrapper holding a `cusolverDnParams_t` handle obtained from
`cusolverDnCreateParams`. A finalizer calls `cusolverDnDestroyParams` on the
wrapper.
"""
mutable struct CuSolverParameters
    parameters::cusolverDnParams_t

    function CuSolverParameters()
        handle = Ref{cusolverDnParams_t}()
        cusolverDnCreateParams(handle)
        # `finalizer` returns the object it was registered on.
        return finalizer(cusolverDnDestroyParams, new(handle[]))
    end
end

# Expose the raw handle when the wrapper is passed where `cusolverDnParams_t` is expected.
function Base.unsafe_convert(::Type{cusolverDnParams_t}, params::CuSolverParameters)
    return params.parameters
end


## CuSolverIRSParameters

"""
    CuSolverIRSParameters()

Mutable wrapper holding a `cusolverDnIRSParams_t` handle obtained from
`cusolverDnIRSParamsCreate`. A finalizer calls `cusolverDnIRSParamsDestroy` on
the wrapper.
"""
mutable struct CuSolverIRSParameters
    parameters::cusolverDnIRSParams_t

    function CuSolverIRSParameters()
        handle = Ref{cusolverDnIRSParams_t}()
        cusolverDnIRSParamsCreate(handle)
        # `finalizer` returns the object it was registered on.
        return finalizer(cusolverDnIRSParamsDestroy, new(handle[]))
    end
end

# Expose the raw handle when the wrapper is passed where `cusolverDnIRSParams_t` is expected.
function Base.unsafe_convert(::Type{cusolverDnIRSParams_t}, params::CuSolverIRSParameters)
    return params.parameters
end

"""
    get_info(params::CuSolverIRSParameters, field::Symbol)

Query a value stored in `params`. The only supported `field` is `:maxiters`,
which is read through `cusolverDnIRSParamsGetMaxIters`.

# Throws
- `ErrorException` for any other `field`.
"""
function get_info(params::CuSolverIRSParameters, field::Symbol)
    # Guard clause: reject unknown fields before touching the library.
    field == :maxiters || error("The information $field is incorrect.")

    value = Ref{Cint}()
    cusolverDnIRSParamsGetMaxIters(params, value)
    return value[]
end


## CuSolverIRSInformation

"""
    CuSolverIRSInformation()

Mutable wrapper holding a `cusolverDnIRSInfos_t` handle obtained from
`cusolverDnIRSInfosCreate`. A finalizer calls `cusolverDnIRSInfosDestroy` on the
wrapper.
"""
mutable struct CuSolverIRSInformation
    information::cusolverDnIRSInfos_t

    function CuSolverIRSInformation()
        handle = Ref{cusolverDnIRSInfos_t}()
        cusolverDnIRSInfosCreate(handle)
        # `finalizer` returns the object it was registered on.
        return finalizer(cusolverDnIRSInfosDestroy, new(handle[]))
    end
end

# Expose the raw handle when the wrapper is passed where `cusolverDnIRSInfos_t` is expected.
function Base.unsafe_convert(::Type{cusolverDnIRSInfos_t}, info::CuSolverIRSInformation)
    return info.information
end

"""
    get_info(info::CuSolverIRSInformation, field::Symbol)

Query an integer value stored in `info`. Supported fields:

- `:niters`       -> `cusolverDnIRSInfosGetNiters`
- `:outer_niters` -> `cusolverDnIRSInfosGetOuterNiters`
- `:maxiters`     -> `cusolverDnIRSInfosGetMaxIters`

# Throws
- `ErrorException` for any other `field`.
"""
function get_info(info::CuSolverIRSInformation, field::Symbol)
    # TODO: `:residual_history` (via `cusolverDnIRSInfosGetResidualHistory`) is
    # not exposed yet.
    value = Ref{Cint}()
    if field == :niters
        cusolverDnIRSInfosGetNiters(info, value)
    elseif field == :outer_niters
        cusolverDnIRSInfosGetOuterNiters(info, value)
    elseif field == :maxiters
        cusolverDnIRSInfosGetMaxIters(info, value)
    else
        error("The information $field is incorrect.")
    end
    return value[]
end


## MatrixDescriptor

"""
    MatrixDescriptor(a, grid; rowblocks=size(a, 1), colblocks=size(a, 2), elta=eltype(a))

Mutable wrapper holding a `cudaLibMgMatrixDesc_t` created by
`cusolverMgCreateMatrixDesc` from the dimensions of `a`, the block sizes, the
element type `elta` and `grid`. No finalizer is registered for the handle.
"""
mutable struct MatrixDescriptor
    desc::cudaLibMgMatrixDesc_t

    function MatrixDescriptor(a, grid; rowblocks = size(a, 1), colblocks = size(a, 2), elta=eltype(a) )
        desc_ref = Ref{cudaLibMgMatrixDesc_t}()
        nrows = size(a, 1)
        ncols = size(a, 2)
        cusolverMgCreateMatrixDesc(desc_ref, nrows, ncols, rowblocks, colblocks, elta, grid)
        return new(desc_ref[])
    end
end

# Expose the raw descriptor when the wrapper is passed where `cudaLibMgMatrixDesc_t` is expected.
function Base.unsafe_convert(::Type{cudaLibMgMatrixDesc_t}, obj::MatrixDescriptor)
    return obj.desc
end


## DeviceGrid

"""
    DeviceGrid(num_row_devs, num_col_devs, deviceIds, mapping)

Mutable wrapper holding a `cudaLibMgGrid_t` created by
`cusolverMgCreateDeviceGrid`. `num_row_devs` must equal 1 (checked with
`@assert`). No finalizer is registered for the handle.
"""
mutable struct DeviceGrid
    desc::cudaLibMgGrid_t

    function DeviceGrid(num_row_devs, num_col_devs, deviceIds, mapping)
        @assert num_row_devs == 1 "Only 1-D column block cyclic is supported, so numRowDevices must be equal to 1."
        grid_ref = Ref{cudaLibMgGrid_t}()
        cusolverMgCreateDeviceGrid(grid_ref, num_row_devs, num_col_devs, deviceIds, mapping)
        return new(grid_ref[])
    end
end

# Expose the raw grid handle when the wrapper is passed where `cudaLibMgGrid_t` is expected.
function Base.unsafe_convert(::Type{cudaLibMgGrid_t}, obj::DeviceGrid)
    return obj.desc
end
25 changes: 0 additions & 25 deletions lib/cusolver/multigpu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,6 @@
# NOTE: in the cublasMg preview, which also relies on this functionality, a separate library
# called 'cudalibmg' is introduced. factor this out when we actually ship that.

"""
    MatrixDescriptor(a, grid; rowblocks=size(a, 1), colblocks=size(a, 2), elta=eltype(a))

Mutable wrapper holding a `cudaLibMgMatrixDesc_t` created by
`cusolverMgCreateMatrixDesc` from the dimensions of `a`, the block sizes, the
element type `elta` and `grid`. No finalizer is registered for the handle.
"""
mutable struct MatrixDescriptor
    desc::cudaLibMgMatrixDesc_t

    function MatrixDescriptor(a, grid; rowblocks = size(a, 1), colblocks = size(a, 2), elta=eltype(a) )
        desc_ref = Ref{cudaLibMgMatrixDesc_t}()
        nrows = size(a, 1)
        ncols = size(a, 2)
        cusolverMgCreateMatrixDesc(desc_ref, nrows, ncols, rowblocks, colblocks, elta, grid)
        return new(desc_ref[])
    end
end

# Expose the raw descriptor when the wrapper is passed where `cudaLibMgMatrixDesc_t` is expected.
function Base.unsafe_convert(::Type{cudaLibMgMatrixDesc_t}, obj::MatrixDescriptor)
    return obj.desc
end

"""
    DeviceGrid(num_row_devs, num_col_devs, deviceIds, mapping)

Mutable wrapper holding a `cudaLibMgGrid_t` created by
`cusolverMgCreateDeviceGrid`. `num_row_devs` must equal 1 (checked with
`@assert`). No finalizer is registered for the handle.
"""
mutable struct DeviceGrid
    desc::cudaLibMgGrid_t

    function DeviceGrid(num_row_devs, num_col_devs, deviceIds, mapping)
        @assert num_row_devs == 1 "Only 1-D column block cyclic is supported, so numRowDevices must be equal to 1."
        grid_ref = Ref{cudaLibMgGrid_t}()
        cusolverMgCreateDeviceGrid(grid_ref, num_row_devs, num_col_devs, deviceIds, mapping)
        return new(grid_ref[])
    end
end

# Expose the raw grid handle when the wrapper is passed where `cudaLibMgGrid_t` is expected.
function Base.unsafe_convert(::Type{cudaLibMgGrid_t}, obj::DeviceGrid)
    return obj.desc
end

function allocateBuffers(n_row_devs, n_col_devs, mat::Matrix)
mat_row_block_size = div(size(mat, 1), n_row_devs)
mat_col_block_size = div(size(mat, 2), n_col_devs)
Expand Down
Loading