forked from IntelPython/dpnp
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement of dpnp.linalg.slogdet() (IntelPython#1607)
* Add a new impl of dpnp.linalg._lu_factor * Get dev_info_array after calling getrf * Add an extra dev_info return to _lu_factor * qwe * Add a logic for a.ndim > 2 in _lu_factor * Add an implementation of dpnp.linalg.slogdet * Add a new test_norms.py file in cupy tests * Expand test scope in public CI * A small update _lu_factor func * Remove w/a for dpnp.count_nonzero in slogdet * getrf returns pair of events and uses dpctl.utils.keep_args_alive * Update dpnp.linalg.det using slogdet * Add new cupy tests for dpnp.linalg.det * Add ipiv_vecs and dev_info_vecs in _lu_factor for the batch case * Skip test_det on CPU due to bug in MKL * Small update of cupy tests in test_norms.py * Add support of complex dtype for dpnp.diagonal and update test_diagonal * lu_factor func returns the result of LU decomposition as c-contiguous and add explanatory comments * Add getrf_batch MKL extension * Update docstring for slogdet * Add more tests * Remove accidentally added file * Modify sign parameter calculation * Remove the old backend implementation of dpnp_det * qwe * Keep lexographical order * Add dpnp_slogdet to dpnp_utils_linalg * Move _lu_factor above * A minor update * A minor changes for _lu_factor * Remove trash files * Use getrf_batch only on CPU * Update tests for dpnp.linalg.slogdet * Address remarks * Add _real_type func * Add test_det in test_usm_type * Add more checks in getrf and getf_batch functions * Improve error handler in getrf_impl * Improve error handler in getrf_batch_impl * dev_info is allocated as zeros * Remove skipif for singular tests * Implement _lu_factor logic with dev_info as a python list * Update getrf_rf error handler with mkl_lapack::batch_error * Remove passing n parameter to _getrf * Add a new test_slogdet_singular_matrix_3D test * Update tests for dpnp.linalg.det * Use is_exception_caught flag in getrf and getrf_batch error handler * Update gesv error handler * Reshape results after calling getrf_batch * Add a new dpnp.linalg.det impl and refresh dpnp_utils_linalg * Remove Limitations from dpnp_det and dpnp_slogdet docstings * Address remarks * Remove det_dtype variable and use the abs val of diag for det * Expand cupy tests for dpnp.linalg.det() * Update TestDet and TestSlogdet
- Loading branch information
1 parent
3fdb921
commit 7e54eb8
Showing
20 changed files
with
1,625 additions
and
135 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,256 @@ | ||
//***************************************************************************** | ||
// Copyright (c) 2023, Intel Corporation | ||
// All rights reserved. | ||
// | ||
// Redistribution and use in source and binary forms, with or without | ||
// modification, are permitted provided that the following conditions are met: | ||
// - Redistributions of source code must retain the above copyright notice, | ||
// this list of conditions and the following disclaimer. | ||
// - Redistributions in binary form must reproduce the above copyright notice, | ||
// this list of conditions and the following disclaimer in the documentation | ||
// and/or other materials provided with the distribution. | ||
// | ||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE | ||
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | ||
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | ||
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | ||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | ||
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | ||
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF | ||
// THE POSSIBILITY OF SUCH DAMAGE. | ||
//***************************************************************************** | ||
|
||
#include <pybind11/pybind11.h> | ||
|
||
// dpctl tensor headers | ||
#include "utils/memory_overlap.hpp" | ||
#include "utils/type_utils.hpp" | ||
|
||
#include "getrf.hpp" | ||
#include "types_matrix.hpp" | ||
|
||
#include "dpnp_utils.hpp" | ||
|
||
namespace dpnp | ||
{ | ||
namespace backend | ||
{ | ||
namespace ext | ||
{ | ||
namespace lapack | ||
{ | ||
namespace mkl_lapack = oneapi::mkl::lapack; | ||
namespace py = pybind11; | ||
namespace type_utils = dpctl::tensor::type_utils; | ||
|
||
typedef sycl::event (*getrf_impl_fn_ptr_t)(sycl::queue, | ||
const std::int64_t, | ||
char *, | ||
std::int64_t, | ||
std::int64_t *, | ||
py::list, | ||
std::vector<sycl::event> &, | ||
const std::vector<sycl::event> &); | ||
|
||
static getrf_impl_fn_ptr_t getrf_dispatch_vector[dpctl_td_ns::num_types]; | ||
|
||
template <typename T> | ||
static sycl::event getrf_impl(sycl::queue exec_q, | ||
const std::int64_t n, | ||
char *in_a, | ||
std::int64_t lda, | ||
std::int64_t *ipiv, | ||
py::list dev_info, | ||
std::vector<sycl::event> &host_task_events, | ||
const std::vector<sycl::event> &depends) | ||
{ | ||
type_utils::validate_type_for_device<T>(exec_q); | ||
|
||
T *a = reinterpret_cast<T *>(in_a); | ||
|
||
const std::int64_t scratchpad_size = | ||
mkl_lapack::getrf_scratchpad_size<T>(exec_q, n, n, lda); | ||
T *scratchpad = nullptr; | ||
|
||
std::stringstream error_msg; | ||
std::int64_t info = 0; | ||
bool is_exception_caught = false; | ||
|
||
sycl::event getrf_event; | ||
try { | ||
scratchpad = sycl::malloc_device<T>(scratchpad_size, exec_q); | ||
|
||
getrf_event = mkl_lapack::getrf( | ||
exec_q, | ||
n, // The order of the square matrix A (0 ≤ n). | ||
// It must be a non-negative integer. | ||
n, // The number of columns in the square matrix A (0 ≤ n). | ||
// It must be a non-negative integer. | ||
a, // Pointer to the square matrix A (n x n). | ||
lda, // The leading dimension of matrix A. | ||
// It must be at least max(1, n). | ||
ipiv, // Pointer to the output array of pivot indices. | ||
scratchpad, // Pointer to scratchpad memory to be used by MKL | ||
// routine for storing intermediate results. | ||
scratchpad_size, depends); | ||
} catch (mkl_lapack::exception const &e) { | ||
is_exception_caught = true; | ||
info = e.info(); | ||
|
||
if (info < 0) { | ||
error_msg << "Parameter number " << -info | ||
<< " had an illegal value."; | ||
} | ||
else if (info == scratchpad_size && e.detail() != 0) { | ||
error_msg | ||
<< "Insufficient scratchpad size. Required size is at least " | ||
<< e.detail(); | ||
} | ||
else if (info > 0) { | ||
// Store the positive 'info' value in the first element of | ||
// 'dev_info'. This indicates that the factorization has been | ||
// completed, but the factor U (upper triangular matrix) is exactly | ||
// singular. The 'info' value here is the index of the first zero | ||
// element in the diagonal of U. | ||
is_exception_caught = false; | ||
dev_info[0] = info; | ||
} | ||
else { | ||
error_msg << "Unexpected MKL exception caught during getrf() " | ||
"call:\nreason: " | ||
<< e.what() << "\ninfo: " << e.info(); | ||
} | ||
} catch (sycl::exception const &e) { | ||
is_exception_caught = true; | ||
error_msg << "Unexpected SYCL exception caught during getrf() call:\n" | ||
<< e.what(); | ||
} | ||
|
||
if (is_exception_caught) // an unexpected error occurs | ||
{ | ||
if (scratchpad != nullptr) { | ||
sycl::free(scratchpad, exec_q); | ||
} | ||
|
||
throw std::runtime_error(error_msg.str()); | ||
} | ||
|
||
sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) { | ||
cgh.depends_on(getrf_event); | ||
auto ctx = exec_q.get_context(); | ||
cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); }); | ||
}); | ||
host_task_events.push_back(clean_up_event); | ||
return getrf_event; | ||
} | ||
|
||
std::pair<sycl::event, sycl::event> | ||
getrf(sycl::queue exec_q, | ||
dpctl::tensor::usm_ndarray a_array, | ||
dpctl::tensor::usm_ndarray ipiv_array, | ||
py::list dev_info, | ||
const std::vector<sycl::event> &depends) | ||
{ | ||
const int a_array_nd = a_array.get_ndim(); | ||
const int ipiv_array_nd = ipiv_array.get_ndim(); | ||
|
||
if (a_array_nd != 2) { | ||
throw py::value_error( | ||
"The input array has ndim=" + std::to_string(a_array_nd) + | ||
", but a 2-dimensional array is expected."); | ||
} | ||
|
||
if (ipiv_array_nd != 1) { | ||
throw py::value_error("The array of pivot indices has ndim=" + | ||
std::to_string(ipiv_array_nd) + | ||
", but a 1-dimensional array is expected."); | ||
} | ||
|
||
// check compatibility of execution queue and allocation queue | ||
if (!dpctl::utils::queues_are_compatible(exec_q, {a_array, ipiv_array})) { | ||
throw py::value_error( | ||
"Execution queue is not compatible with allocation queues"); | ||
} | ||
|
||
auto const &overlap = dpctl::tensor::overlap::MemoryOverlap(); | ||
if (overlap(a_array, ipiv_array)) { | ||
throw py::value_error("The input array and the array of pivot indices " | ||
"are overlapping segments of memory"); | ||
} | ||
|
||
bool is_a_array_c_contig = a_array.is_c_contiguous(); | ||
bool is_ipiv_array_c_contig = ipiv_array.is_c_contiguous(); | ||
if (!is_a_array_c_contig) { | ||
throw py::value_error("The input array " | ||
"must be C-contiguous"); | ||
} | ||
if (!is_ipiv_array_c_contig) { | ||
throw py::value_error("The array of pivot indices " | ||
"must be C-contiguous"); | ||
} | ||
|
||
auto array_types = dpctl_td_ns::usm_ndarray_types(); | ||
int a_array_type_id = | ||
array_types.typenum_to_lookup_id(a_array.get_typenum()); | ||
|
||
getrf_impl_fn_ptr_t getrf_fn = getrf_dispatch_vector[a_array_type_id]; | ||
if (getrf_fn == nullptr) { | ||
throw py::value_error( | ||
"No getrf implementation defined for the provided type " | ||
"of the input matrix."); | ||
} | ||
|
||
auto ipiv_types = dpctl_td_ns::usm_ndarray_types(); | ||
int ipiv_array_type_id = | ||
ipiv_types.typenum_to_lookup_id(ipiv_array.get_typenum()); | ||
|
||
if (ipiv_array_type_id != static_cast<int>(dpctl_td_ns::typenum_t::INT64)) { | ||
throw py::value_error("The type of 'ipiv_array' must be int64."); | ||
} | ||
|
||
const std::int64_t n = a_array.get_shape_raw()[0]; | ||
|
||
char *a_array_data = a_array.get_data(); | ||
const std::int64_t lda = std::max<size_t>(1UL, n); | ||
|
||
char *ipiv_array_data = ipiv_array.get_data(); | ||
std::int64_t *d_ipiv = reinterpret_cast<std::int64_t *>(ipiv_array_data); | ||
|
||
std::vector<sycl::event> host_task_events; | ||
sycl::event getrf_ev = getrf_fn(exec_q, n, a_array_data, lda, d_ipiv, | ||
dev_info, host_task_events, depends); | ||
|
||
sycl::event args_ev = dpctl::utils::keep_args_alive( | ||
exec_q, {a_array, ipiv_array}, host_task_events); | ||
|
||
return std::make_pair(args_ev, getrf_ev); | ||
} | ||
|
||
template <typename fnT, typename T> | ||
struct GetrfContigFactory | ||
{ | ||
fnT get() | ||
{ | ||
if constexpr (types::GetrfTypePairSupportFactory<T>::is_defined) { | ||
return getrf_impl<T>; | ||
} | ||
else { | ||
return nullptr; | ||
} | ||
} | ||
}; | ||
|
||
void init_getrf_dispatch_vector(void) | ||
{ | ||
dpctl_td_ns::DispatchVectorBuilder<getrf_impl_fn_ptr_t, GetrfContigFactory, | ||
dpctl_td_ns::num_types> | ||
contig; | ||
contig.populate_dispatch_vector(getrf_dispatch_vector); | ||
} | ||
} // namespace lapack | ||
} // namespace ext | ||
} // namespace backend | ||
} // namespace dpnp |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
//***************************************************************************** | ||
// Copyright (c) 2023, Intel Corporation | ||
// All rights reserved. | ||
// | ||
// Redistribution and use in source and binary forms, with or without | ||
// modification, are permitted provided that the following conditions are met: | ||
// - Redistributions of source code must retain the above copyright notice, | ||
// this list of conditions and the following disclaimer. | ||
// - Redistributions in binary form must reproduce the above copyright notice, | ||
// this list of conditions and the following disclaimer in the documentation | ||
// and/or other materials provided with the distribution. | ||
// | ||
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | ||
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE | ||
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE | ||
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR | ||
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF | ||
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS | ||
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN | ||
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) | ||
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF | ||
// THE POSSIBILITY OF SUCH DAMAGE. | ||
//***************************************************************************** | ||
|
||
#pragma once | ||
|
||
#include <CL/sycl.hpp> | ||
#include <oneapi/mkl.hpp> | ||
|
||
#include <dpctl4pybind11.hpp> | ||
|
||
namespace dpnp | ||
{ | ||
namespace backend | ||
{ | ||
namespace ext | ||
{ | ||
namespace lapack | ||
{ | ||
extern std::pair<sycl::event, sycl::event> | ||
getrf(sycl::queue exec_q, | ||
dpctl::tensor::usm_ndarray a_array, | ||
dpctl::tensor::usm_ndarray ipiv_array, | ||
py::list dev_info, | ||
const std::vector<sycl::event> &depends = {}); | ||
|
||
extern std::pair<sycl::event, sycl::event> | ||
getrf_batch(sycl::queue exec_q, | ||
dpctl::tensor::usm_ndarray a_array, | ||
dpctl::tensor::usm_ndarray ipiv_array, | ||
py::list dev_info, | ||
std::int64_t n, | ||
std::int64_t stride_a, | ||
std::int64_t stride_ipiv, | ||
std::int64_t batch_size, | ||
const std::vector<sycl::event> &depends = {}); | ||
|
||
extern void init_getrf_dispatch_vector(void); | ||
extern void init_getrf_batch_dispatch_vector(void); | ||
} // namespace lapack | ||
} // namespace ext | ||
} // namespace backend | ||
} // namespace dpnp |
Oops, something went wrong.