Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend gesv_impl/gesv_batch_impl for work with oneMKL Interfaces #2001

Merged
merged 15 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 72 additions & 13 deletions dpnp/backend/extensions/lapack/gesv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,6 @@ static sycl::event gesv_impl(sycl::queue &exec_q,
char *in_b,
const std::vector<sycl::event> &depends)
{
#if defined(USE_ONEMKL_INTERFACES)
// Temporary flag for build only
// FIXME: Need to implement by using lapack::getrf and lapack::getrs
std::logic_error("Not Implemented");
#else
type_utils::validate_type_for_device<T>(exec_q);

T *a = reinterpret_cast<T *>(in_a);
Expand All @@ -69,12 +64,31 @@ static sycl::event gesv_impl(sycl::queue &exec_q,
const std::int64_t lda = std::max<size_t>(1UL, n);
const std::int64_t ldb = std::max<size_t>(1UL, n);

const std::int64_t scratchpad_size =
std::int64_t scratchpad_size = 0;
sycl::event comp_event;
std::int64_t *ipiv = nullptr;

std::stringstream error_msg;
bool is_exception_caught = false;

#if defined(USE_ONEMKL_INTERFACES)
// Use transpose::T if the LU-factorized array is passed as C-contiguous.
// For F-contiguous we use transpose::N.
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
// Since gesv takes F-contiguous as input, we use transpose::N.
oneapi::mkl::transpose trans = oneapi::mkl::transpose::N;

scratchpad_size = std::max(
mkl_lapack::getrf_scratchpad_size<T>(exec_q, n, n, lda),
mkl_lapack::getrs_scratchpad_size<T>(exec_q, trans, n, nrhs, lda, ldb));

#else
scratchpad_size =
mkl_lapack::gesv_scratchpad_size<T>(exec_q, n, nrhs, lda, ldb);

#endif // USE_ONEMKL_INTERFACES

T *scratchpad = helper::alloc_scratchpad<T>(scratchpad_size, exec_q);

std::int64_t *ipiv = nullptr;
try {
ipiv = helper::alloc_ipiv(n, exec_q);
} catch (const std::exception &e) {
Expand All @@ -83,12 +97,57 @@ static sycl::event gesv_impl(sycl::queue &exec_q,
throw;
}
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved

std::stringstream error_msg;
bool is_exception_caught = false;
#if defined(USE_ONEMKL_INTERFACES)
sycl::event getrf_event;
try {
getrf_event = mkl_lapack::getrf(
exec_q,
n, // The order of the square matrix A (0 ≤ n).
// It must be a non-negative integer.
n, // The number of columns in the square matrix A (0 ≤ n).
// It must be a non-negative integer.
a, // Pointer to the square matrix A (n x n).
lda, // The leading dimension of matrix A.
// It must be at least max(1, n).
ipiv, // Pointer to the output array of pivot indices.
scratchpad, // Pointer to scratchpad memory to be used by MKL
// routine for storing intermediate results.
scratchpad_size, depends);

sycl::event gesv_event;
comp_event = mkl_lapack::getrs(
exec_q,
trans, // Specifies the operation: whether or not to transpose
// matrix A. Can be 'N' for no transpose, 'T' for transpose,
// and 'C' for conjugate transpose.
n, // The order of the square matrix A
// and the number of rows in matrix B (0 ≤ n).
// It must be a non-negative integer.
nrhs, // The number of right-hand sides,
// i.e., the number of columns in matrix B (0 ≤ nrhs).
a, // Pointer to the square matrix A (n x n).
lda, // The leading dimension of matrix A.
// It must be at least max(1, n).
ipiv, // Pointer to the output array of pivot indices that were used
// during factorization (n, ).
b, // Pointer to the matrix B of right-hand sides (ldb, nrhs).
ldb, // The leading dimension of matrix B, must be at least max(1,
// n).
scratchpad, // Pointer to scratchpad memory to be used by MKL
// routine for storing intermediate results.
scratchpad_size, {getrf_event});
} catch (mkl_lapack::exception const &e) {
is_exception_caught = true;
gesv_utils::handle_lapack_exc(exec_q, lda, a, scratchpad_size,
scratchpad, ipiv, e, error_msg);
} catch (sycl::exception const &e) {
is_exception_caught = true;
error_msg << "Unexpected SYCL exception caught during getrf() or "
"getrs() call:\n"
<< e.what();
}
#else
try {
gesv_event = mkl_lapack::gesv(
comp_event = mkl_lapack::gesv(
exec_q,
n, // The order of the square matrix A
// and the number of rows in matrix B (0 ≤ n).
Expand All @@ -114,6 +173,7 @@ static sycl::event gesv_impl(sycl::queue &exec_q,
error_msg << "Unexpected SYCL exception caught during gesv() call:\n"
<< e.what();
}
#endif // USE_ONEMKL_INTERFACES

if (is_exception_caught) // an unexpected error occurs
{
Expand All @@ -125,7 +185,7 @@ static sycl::event gesv_impl(sycl::queue &exec_q,
}

sycl::event ht_ev = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(gesv_event);
cgh.depends_on(comp_event);
auto ctx = exec_q.get_context();
cgh.host_task([ctx, scratchpad, ipiv]() {
sycl::free(scratchpad, ctx);
Expand All @@ -134,7 +194,6 @@ static sycl::event gesv_impl(sycl::queue &exec_q,
});

return ht_ev;
#endif // USE_ONEMKL_INTERFACES
}

std::pair<sycl::event, sycl::event>
Expand Down
176 changes: 164 additions & 12 deletions dpnp/backend/extensions/lapack/gesv_batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,10 @@ typedef sycl::event (*gesv_batch_impl_fn_ptr_t)(
const std::int64_t,
const std::int64_t,
const std::int64_t,
#if defined(USE_ONEMKL_INTERFACES)
const std::int64_t,
const std::int64_t,
vlad-perevezentsev marked this conversation as resolved.
Show resolved Hide resolved
#endif // USE_ONEMKL_INTERFACES
char *,
char *,
const std::vector<sycl::event> &);
Expand All @@ -56,6 +60,10 @@ static sycl::event gesv_batch_impl(sycl::queue &exec_q,
const std::int64_t n,
const std::int64_t nrhs,
const std::int64_t batch_size,
#if defined(USE_ONEMKL_INTERFACES)
const std::int64_t stride_a,
const std::int64_t stride_b,
#endif // USE_ONEMKL_INTERFACES
char *in_a,
char *in_b,
const std::vector<sycl::event> &depends)
Expand All @@ -65,23 +73,147 @@ static sycl::event gesv_batch_impl(sycl::queue &exec_q,
T *a = reinterpret_cast<T *>(in_a);
T *b = reinterpret_cast<T *>(in_b);

const std::int64_t a_size = n * n;
const std::int64_t b_size = n * nrhs;

const std::int64_t lda = std::max<size_t>(1UL, n);
const std::int64_t ldb = std::max<size_t>(1UL, n);

std::int64_t scratchpad_size = 0;
sycl::event comp_event;
std::int64_t *ipiv = nullptr;
T *scratchpad = nullptr;

std::stringstream error_msg;
bool is_exception_caught = false;

#if defined(USE_ONEMKL_INTERFACES)
// Use transpose::T if the LU-factorized array is passed as C-contiguous.
// For F-contiguous we use transpose::N.
// Since gesv_batch takes F-contiguous as input, we use transpose::N.
oneapi::mkl::transpose trans = oneapi::mkl::transpose::N;
const std::int64_t stride_ipiv = n;

scratchpad_size = std::max(
mkl_lapack::getrs_batch_scratchpad_size<T>(exec_q, trans, n, nrhs, lda,
stride_a, stride_ipiv, ldb,
stride_b, batch_size),
mkl_lapack::getrf_batch_scratchpad_size<T>(exec_q, n, n, lda, stride_a,
stride_ipiv, batch_size));

scratchpad = helper::alloc_scratchpad<T>(scratchpad_size, exec_q);

// pass batch_size * n to allocate the memory for a 2D array of pivot
// indices
try {
ipiv = helper::alloc_ipiv(batch_size * n, exec_q);
} catch (const std::exception &e) {
if (scratchpad != nullptr)
sycl::free(scratchpad, exec_q);
throw;
}

sycl::event getrf_batch_event;
try {
getrf_batch_event = mkl_lapack::getrf_batch(
exec_q,
n, // The order of each square matrix in the batch; (0 ≤ n).
// It must be a non-negative integer.
n, // The number of columns in each matrix in the batch; (0 ≤ n).
// It must be a non-negative integer.
a, // Pointer to the batch of square matrices, each of size (n x n).
lda, // The leading dimension of each matrix in the batch.
stride_a, // Stride between consecutive matrices in the batch.
ipiv, // Pointer to the array of pivot indices for each matrix in
// the batch.
stride_ipiv, // Stride between the pivot index arrays of consecutive
// matrices in 'ipiv'.
batch_size, // The number of matrices in the batch.
scratchpad, // Pointer to scratchpad memory to be used by MKL
// routine for storing intermediate results.
scratchpad_size, depends);

comp_event = mkl_lapack::getrs_batch(
exec_q,
trans, // Specifies the operation: whether or not to transpose
// matrix A. Can be 'N' for no transpose, 'T' for transpose,
// and 'C' for conjugate transpose.
n, // The order of each square matrix A in the batch
// and the number of rows in each matrix B (0 ≤ n).
// It must be a non-negative integer.
nrhs, // The number of right-hand sides,
// i.e., the number of columns in each matrix B in the batch
// (0 ≤ nrhs).
a, // Pointer to the batch of square matrices A (n x n).
lda, // The leading dimension of each matrix A in the batch.
// It must be at least max(1, n).
stride_a, // Stride between individual matrices in the batch for
// matrix A.
ipiv, // Pointer to the batch of arrays of pivot indices.
stride_ipiv, // Stride between pivot index arrays in the batch.
b, // Pointer to the batch of matrices B (n, nrhs).
ldb, // The leading dimension of each matrix B in the batch.
// Must be at least max(1, n).
stride_b, // Stride between individual matrices in the batch for
// matrix B.
batch_size, // The number of matrices in the batch.
scratchpad, // Pointer to scratchpad memory to be used by MKL
// routine for storing intermediate results.
scratchpad_size, {getrf_batch_event});
} catch (mkl_lapack::batch_error const &be) {
// Get the indices of matrices within the batch that encountered an
// error
auto error_matrices_ids = be.ids();

error_msg << "Singular matrix. Errors in matrices with IDs: ";
for (size_t i = 0; i < error_matrices_ids.size(); ++i) {
error_msg << error_matrices_ids[i];
if (i < error_matrices_ids.size() - 1) {
error_msg << ", ";
}
}
error_msg << ".";

if (scratchpad != nullptr)
sycl::free(scratchpad, exec_q);
if (ipiv != nullptr)
sycl::free(ipiv, exec_q);

throw LinAlgError(error_msg.str().c_str());
} catch (mkl_lapack::exception const &e) {
is_exception_caught = true;
std::int64_t info = e.info();
if (info < 0) {
error_msg << "Parameter number " << -info
<< " had an illegal value.";
}
else if (info == scratchpad_size && e.detail() != 0) {
error_msg
<< "Insufficient scratchpad size. Required size is at least "
<< e.detail();
}
else {
error_msg << "Unexpected MKL exception caught during getrf_batch() "
"or getrs_batch() call:\nreason: "
<< e.what() << "\ninfo: " << e.info();
}
} catch (sycl::exception const &e) {
is_exception_caught = true;
error_msg << "Unexpected SYCL exception caught during getrf() or "
"getrs() call:\n"
<< e.what();
}
#else
const std::int64_t a_size = n * n;
const std::int64_t b_size = n * nrhs;

// Get the number of independent linear streams
const std::int64_t n_linear_streams =
(batch_size > 16) ? 4 : ((batch_size > 4 ? 2 : 1));

const std::int64_t scratchpad_size =
scratchpad_size =
mkl_lapack::gesv_scratchpad_size<T>(exec_q, n, nrhs, lda, ldb);

T *scratchpad = helper::alloc_scratchpad_batch<T>(scratchpad_size,
n_linear_streams, exec_q);
scratchpad = helper::alloc_scratchpad_batch<T>(scratchpad_size,
n_linear_streams, exec_q);

std::int64_t *ipiv = nullptr;
try {
ipiv = helper::alloc_ipiv_batch<T>(n, n_linear_streams, exec_q);
} catch (const std::exception &e) {
Expand All @@ -93,9 +225,6 @@ static sycl::event gesv_batch_impl(sycl::queue &exec_q,
// Computation events to manage dependencies for each linear stream
std::vector<std::vector<sycl::event>> comp_evs(n_linear_streams, depends);

std::stringstream error_msg;
bool is_exception_caught = false;

// Release GIL to avoid serialization of host task
// submissions to the same queue in OneMKL
py::gil_scoped_release release;
Expand Down Expand Up @@ -147,6 +276,7 @@ static sycl::event gesv_batch_impl(sycl::queue &exec_q,
// Update the event dependencies for the current stream
comp_evs[stream_id] = {gesv_event};
}
#endif // USE_ONEMKL_INTERFACES

if (is_exception_caught) // an unexpected error occurs
{
Expand All @@ -158,9 +288,13 @@ static sycl::event gesv_batch_impl(sycl::queue &exec_q,
}

sycl::event ht_ev = exec_q.submit([&](sycl::handler &cgh) {
#if defined(USE_ONEMKL_INTERFACES)
cgh.depends_on(comp_event);
#else
for (const auto &ev : comp_evs) {
cgh.depends_on(ev);
}
#endif // USE_ONEMKL_INTERFACES
auto ctx = exec_q.get_context();
cgh.host_task([ctx, scratchpad, ipiv]() {
sycl::free(scratchpad, ctx);
Expand Down Expand Up @@ -242,9 +376,27 @@ std::pair<sycl::event, sycl::event>
const std::int64_t nrhs =
(dependent_vals_nd > 2) ? dependent_vals_shape[1] : 1;

sycl::event gesv_ev =
gesv_batch_fn(exec_q, n, nrhs, batch_size, coeff_matrix_data,
sycl::event gesv_ev;

#if defined(USE_ONEMKL_INTERFACES)
auto const &coeff_matrix_strides = coeff_matrix.get_strides_vector();
auto const &dependent_vals_strides = dependent_vals.get_strides_vector();

// Get the strides for the batch matrices.
// Since the matrices are stored in F-contiguous order,
// the stride between batches is the last element in the strides vector.
const std::int64_t coeff_matrix_batch_stride = coeff_matrix_strides.back();
const std::int64_t dependent_vals_batch_stride =
dependent_vals_strides.back();

gesv_ev =
gesv_batch_fn(exec_q, n, nrhs, batch_size, coeff_matrix_batch_stride,
dependent_vals_batch_stride, coeff_matrix_data,
dependent_vals_data, depends);
#else
gesv_ev = gesv_batch_fn(exec_q, n, nrhs, batch_size, coeff_matrix_data,
dependent_vals_data, depends);
#endif // USE_ONEMKL_INTERFACES

sycl::event ht_ev = dpctl::utils::keep_args_alive(
exec_q, {coeff_matrix, dependent_vals}, {gesv_ev});
Expand Down
6 changes: 3 additions & 3 deletions dpnp/backend/extensions/lapack/getrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ namespace mkl_lapack = oneapi::mkl::lapack;
namespace py = pybind11;
namespace type_utils = dpctl::tensor::type_utils;

typedef sycl::event (*getrs_impl_fn_ptr_t)(sycl::queue,
typedef sycl::event (*getrs_impl_fn_ptr_t)(sycl::queue &,
antonwolfy marked this conversation as resolved.
Show resolved Hide resolved
oneapi::mkl::transpose,
const std::int64_t,
const std::int64_t,
Expand All @@ -56,7 +56,7 @@ typedef sycl::event (*getrs_impl_fn_ptr_t)(sycl::queue,
static getrs_impl_fn_ptr_t getrs_dispatch_vector[dpctl_td_ns::num_types];

template <typename T>
static sycl::event getrs_impl(sycl::queue exec_q,
static sycl::event getrs_impl(sycl::queue &exec_q,
oneapi::mkl::transpose trans,
const std::int64_t n,
const std::int64_t nrhs,
Expand Down Expand Up @@ -156,7 +156,7 @@ static sycl::event getrs_impl(sycl::queue exec_q,
}

std::pair<sycl::event, sycl::event>
getrs(sycl::queue exec_q,
getrs(sycl::queue &exec_q,
dpctl::tensor::usm_ndarray a_array,
dpctl::tensor::usm_ndarray ipiv_array,
dpctl::tensor::usm_ndarray b_array,
Expand Down
Loading
Loading