Skip to content

Commit

Permalink
Add two functions BufferSize
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Dec 14, 2024
1 parent df4afa1 commit fc6f0be
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions lib/cusolver/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,7 @@ function gesv!(X::CuVecOrMat{T}, A::CuMatrix{T}, B::CuVecOrMat{T}; fallback::Boo
ldb = max(1, stride(B, 2))
ldx = max(1, stride(X, 2))
niters = Ref{Cint}()
dh = dense_handle()

if irs_precision == "AUTO"
(T == Float32) && (irs_precision = "R_32F")
Expand All @@ -919,13 +920,15 @@ function gesv!(X::CuVecOrMat{T}, A::CuMatrix{T}, B::CuVecOrMat{T}; fallback::Boo
fallback ? cusolverDnIRSParamsEnableFallback(params) : cusolverDnIRSParamsDisableFallback(params)
residual_history && cusolverDnIRSInfosRequestResidual(info)

dh = dense_handle()
buffer_size = Ref{Csize_t}(0)
cusolverDnIRSXgesv_bufferSize(dh, params, n, nrhs, buffer_size)
# Workspace-size query for the iterative-refinement `gesv` solve.
# Zero-argument closure over `dh`, `params`, `n` and `nrhs`; it is handed to
# `with_workspace` as a callback rather than passing a precomputed size.
# NOTE(review): the size is presumably in bytes, since the solver call is given
# `sizeof(buffer)` as the workspace size — confirm against the cuSOLVER docs.
function bufferSize()
# Ref passed to the query, which is expected to store the required size in it.
buffer_size = Ref{Csize_t}(0)
cusolverDnIRSXgesv_bufferSize(dh, params, n, nrhs, buffer_size)
# Unwrap the Ref so the caller receives a plain `Csize_t` value.
return buffer_size[]
end

with_workspace(dh.workspace_gpu, buffer_size[]) do buffer
with_workspace(dh.workspace_gpu, bufferSize) do buffer
cusolverDnIRSXgesv(dh, params, info, n, nrhs, A, lda, B, ldb,
X, ldx, buffer, buffer_size, niter, dh.info)
X, ldx, buffer, sizeof(buffer), niters, dh.info)
end

# Copy the solver flag and delete the device memory
Expand All @@ -948,6 +951,7 @@ function gels!(X::CuVecOrMat{T}, A::CuMatrix{T}, B::CuVecOrMat{T}; fallback::Boo
ldb = max(1, stride(B, 2))
ldx = max(1, stride(X, 2))
niters = Ref{Cint}()
dh = dense_handle()

if irs_precision == "AUTO"
(T == Float32) && (irs_precision = "R_32F")
Expand All @@ -970,13 +974,15 @@ function gels!(X::CuVecOrMat{T}, A::CuMatrix{T}, B::CuVecOrMat{T}; fallback::Boo
fallback ? cusolverDnIRSParamsEnableFallback(params) : cusolverDnIRSParamsDisableFallback(params)
residual_history && cusolverDnIRSInfosRequestResidual(info)

dh = dense_handle()
buffer_size = Ref{Csize_t}(0)
cusolverDnIRSXgels_bufferSize(dh, params, m, n, nrhs, buffer_size)
# Workspace-size query for the iterative-refinement `gels` solve.
# Zero-argument closure over `dh`, `params`, `m`, `n` and `nrhs`; it is handed to
# `with_workspace` as a callback rather than passing a precomputed size.
# NOTE(review): the size is presumably in bytes, since the solver call is given
# `sizeof(buffer)` as the workspace size — confirm against the cuSOLVER docs.
function bufferSize()
# Ref passed to the query, which is expected to store the required size in it.
buffer_size = Ref{Csize_t}(0)
cusolverDnIRSXgels_bufferSize(dh, params, m, n, nrhs, buffer_size)
# Unwrap the Ref so the caller receives a plain `Csize_t` value.
return buffer_size[]
end

with_workspace(dh.workspace_gpu, buffer_size[]) do buffer
with_workspace(dh.workspace_gpu, bufferSize) do buffer
cusolverDnIRSXgels(dh, params, info, m, n, nrhs, A, lda, B, ldb,
X, ldx, buffer, buffer_size, niters, dh.info)
X, ldx, buffer, sizeof(buffer), niters, dh.info)
end

# Copy the solver flag and delete the device memory
Expand Down

0 comments on commit fc6f0be

Please sign in to comment.