Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement of dpnp.linalg.cholesky() #1638

Merged
merged 31 commits into from
Jan 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
0f7ba46
Add a new impl of dpnp.linalg.cholesky
vlad-perevezentsev Dec 7, 2023
ca9ce53
Add cupy tests for dpnp.linalg.cholesky
vlad-perevezentsev Dec 7, 2023
b558252
Add a batch impl of dpnp.linalg.cholesky
vlad-perevezentsev Dec 7, 2023
7c71b3d
Remove an old impl of dpnp_cholesky
vlad-perevezentsev Dec 7, 2023
57ea44a
Merge master into impl_cholesky
vlad-perevezentsev Dec 8, 2023
3a41236
Remove DPNP_FN_CHOLESKY_EXT in dpnp_iface_fptr
vlad-perevezentsev Dec 8, 2023
06fc207
Remove dpnp_cholesky_ext_c
vlad-perevezentsev Dec 8, 2023
8e60468
Add a new _dpnp_cholesky_batch func
vlad-perevezentsev Dec 8, 2023
98911fc
Update test_cholesky in test_sycl_queue
vlad-perevezentsev Dec 8, 2023
8eeeb09
Expand test scope in public CI
vlad-perevezentsev Dec 8, 2023
72d6443
Add more tests for dpnp.linalg.cholesky
vlad-perevezentsev Dec 8, 2023
703ecc3
Merge master into impl_cholesky
vlad-perevezentsev Jan 16, 2024
253dc93
Remove TODOs in cholesky() and update docstings
vlad-perevezentsev Jan 16, 2024
0613072
Use _common_type in dpnp_cholesky
vlad-perevezentsev Jan 16, 2024
1f87a65
Update dpnp_cholesky and dpnp_cholesky_batch
vlad-perevezentsev Jan 16, 2024
3c62207
Keep the lexicographic order
vlad-perevezentsev Jan 16, 2024
c90a006
Remove passing n parameter to _potrf
vlad-perevezentsev Jan 16, 2024
ad17d19
Add additional checks to potrf and potrf_batch
vlad-perevezentsev Jan 16, 2024
e36bdcb
Extend potrf error handler
vlad-perevezentsev Jan 16, 2024
8565346
Extend potrf_batch error handler
vlad-perevezentsev Jan 16, 2024
9d7411b
Update tests for dpnp.linalg.cholesky
vlad-perevezentsev Jan 16, 2024
933c320
Merge master into impl_cholesky
vlad-perevezentsev Jan 16, 2024
3d0484c
Update license year
vlad-perevezentsev Jan 17, 2024
f62ea45
Update cholesky docstrings
vlad-perevezentsev Jan 17, 2024
0668d74
Add support upper paramenetr for potrf
vlad-perevezentsev Jan 17, 2024
153f963
Add support upper paramenetr for potrf_batch and update dpnp_cholesky
vlad-perevezentsev Jan 18, 2024
48cc657
Add tests for upper parameter of dpnp.linalg.cholesky
vlad-perevezentsev Jan 18, 2024
547e029
Merge master into impl_cholesky
vlad-perevezentsev Jan 18, 2024
ca423f8
Address remarks
vlad-perevezentsev Jan 19, 2024
cf701c9
Fix validation check
vlad-perevezentsev Jan 19, 2024
f14ed63
Merge branch 'master' into impl_cholesky
vtavana Jan 20, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/conda-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ env:
test_umath.py
test_usm_type.py
third_party/cupy/core_tests
third_party/cupy/linalg_tests/test_decomposition.py
third_party/cupy/linalg_tests/test_norms.py
third_party/cupy/linalg_tests/test_product.py
third_party/cupy/linalg_tests/test_solve.py
Expand Down
2 changes: 2 additions & 0 deletions dpnp/backend/extensions/lapack/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ set(_module_src
${CMAKE_CURRENT_SOURCE_DIR}/getrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/getrf_batch.cpp
${CMAKE_CURRENT_SOURCE_DIR}/heevd.cpp
${CMAKE_CURRENT_SOURCE_DIR}/potrf.cpp
${CMAKE_CURRENT_SOURCE_DIR}/potrf_batch.cpp
${CMAKE_CURRENT_SOURCE_DIR}/syevd.cpp
)

Expand Down
2 changes: 1 addition & 1 deletion dpnp/backend/extensions/lapack/getrf.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//*****************************************************************************
// Copyright (c) 2023, Intel Corporation
// Copyright (c) 2024, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
Expand Down
2 changes: 1 addition & 1 deletion dpnp/backend/extensions/lapack/getrf.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//*****************************************************************************
// Copyright (c) 2023, Intel Corporation
// Copyright (c) 2024, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
Expand Down
2 changes: 1 addition & 1 deletion dpnp/backend/extensions/lapack/getrf_batch.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
//*****************************************************************************
// Copyright (c) 2023, Intel Corporation
// Copyright (c) 2024, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
Expand Down
17 changes: 17 additions & 0 deletions dpnp/backend/extensions/lapack/lapack_py.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "getrf.hpp"
#include "heevd.hpp"
#include "linalg_exceptions.hpp"
#include "potrf.hpp"
#include "syevd.hpp"

namespace lapack_ext = dpnp::backend::ext::lapack;
Expand All @@ -45,6 +46,8 @@ void init_dispatch_vectors(void)
lapack_ext::init_gesv_dispatch_vector();
lapack_ext::init_getrf_batch_dispatch_vector();
lapack_ext::init_getrf_dispatch_vector();
lapack_ext::init_potrf_batch_dispatch_vector();
lapack_ext::init_potrf_dispatch_vector();
lapack_ext::init_syevd_dispatch_vector();
}

Expand Down Expand Up @@ -92,6 +95,20 @@ PYBIND11_MODULE(_lapack_impl, m)
py::arg("eig_vecs"), py::arg("eig_vals"),
py::arg("depends") = py::list());

m.def("_potrf", &lapack_ext::potrf,
"Call `potrf` from OneMKL LAPACK library to return "
"the Cholesky factorization of a symmetric positive-definite matrix",
py::arg("sycl_queue"), py::arg("a_array"), py::arg("upper_lower"),
py::arg("depends") = py::list());

m.def("_potrf_batch", &lapack_ext::potrf_batch,
"Call `potrf_batch` from OneMKL LAPACK library to return "
"the Cholesky factorization of a batch of symmetric "
"positive-definite matrix",
py::arg("sycl_queue"), py::arg("a_array"), py::arg("upper_lower"),
py::arg("n"), py::arg("stride_a"), py::arg("batch_size"),
py::arg("depends") = py::list());

m.def("_syevd", &lapack_ext::syevd,
"Call `syevd` from OneMKL LAPACK library to return "
"the eigenvalues and eigenvectors of a real symmetric matrix",
Expand Down
221 changes: 221 additions & 0 deletions dpnp/backend/extensions/lapack/potrf.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
//*****************************************************************************
// Copyright (c) 2024, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#include <pybind11/pybind11.h>

// dpctl tensor headers
#include "utils/memory_overlap.hpp"
#include "utils/type_utils.hpp"

#include "linalg_exceptions.hpp"
#include "potrf.hpp"
#include "types_matrix.hpp"

#include "dpnp_utils.hpp"

namespace dpnp
{
namespace backend
{
namespace ext
{
namespace lapack
{
namespace mkl_lapack = oneapi::mkl::lapack;
namespace py = pybind11;
namespace type_utils = dpctl::tensor::type_utils;

typedef sycl::event (*potrf_impl_fn_ptr_t)(sycl::queue,
const oneapi::mkl::uplo,
const std::int64_t,
char *,
std::int64_t,
std::vector<sycl::event> &,
const std::vector<sycl::event> &);

static potrf_impl_fn_ptr_t potrf_dispatch_vector[dpctl_td_ns::num_types];

template <typename T>
static sycl::event potrf_impl(sycl::queue exec_q,
const oneapi::mkl::uplo upper_lower,
const std::int64_t n,
char *in_a,
std::int64_t lda,
std::vector<sycl::event> &host_task_events,
const std::vector<sycl::event> &depends)
{
type_utils::validate_type_for_device<T>(exec_q);

T *a = reinterpret_cast<T *>(in_a);

const std::int64_t scratchpad_size =
mkl_lapack::potrf_scratchpad_size<T>(exec_q, upper_lower, n, lda);
T *scratchpad = nullptr;

std::stringstream error_msg;
std::int64_t info = 0;
bool is_exception_caught = false;

sycl::event potrf_event;
try {
scratchpad = sycl::malloc_device<T>(scratchpad_size, exec_q);

potrf_event = mkl_lapack::potrf(
exec_q,
upper_lower, // An enumeration value of type oneapi::mkl::uplo:
// oneapi::mkl::uplo::upper for the upper triangular
// part; oneapi::mkl::uplo::lower for the lower
// triangular part.
n, // Order of the square matrix; (0 ≤ n).
a, // Pointer to the n-by-n matrix.
lda, // The leading dimension of `a`.
scratchpad, // Pointer to scratchpad memory to be used by MKL
// routine for storing intermediate results.
scratchpad_size, depends);
} catch (mkl_lapack::exception const &e) {
is_exception_caught = true;
info = e.info();
if (info < 0) {
error_msg << "Parameter number " << -info
<< " had an illegal value.";
}
else if (info == scratchpad_size && e.detail() != 0) {
error_msg
<< "Insufficient scratchpad size. Required size is at least "
<< e.detail();
}
else if (info > 0 && e.detail() == 0) {
sycl::free(scratchpad, exec_q);
throw LinAlgError("Matrix is not positive definite.");
}
else {
error_msg << "Unexpected MKL exception caught during getrf() "
"call:\nreason: "
<< e.what() << "\ninfo: " << e.info();
}
} catch (sycl::exception const &e) {
is_exception_caught = true;
error_msg << "Unexpected SYCL exception caught during potrf() call:\n"
<< e.what();
}

if (is_exception_caught) // an unexpected error occurs
{
if (scratchpad != nullptr) {
sycl::free(scratchpad, exec_q);
}
throw std::runtime_error(error_msg.str());
}

sycl::event clean_up_event = exec_q.submit([&](sycl::handler &cgh) {
cgh.depends_on(potrf_event);
auto ctx = exec_q.get_context();
cgh.host_task([ctx, scratchpad]() { sycl::free(scratchpad, ctx); });
});
host_task_events.push_back(clean_up_event);
return potrf_event;
}

std::pair<sycl::event, sycl::event>
potrf(sycl::queue q,
dpctl::tensor::usm_ndarray a_array,
const std::int8_t upper_lower,
const std::vector<sycl::event> &depends)
{
const int a_array_nd = a_array.get_ndim();

if (a_array_nd != 2) {
throw py::value_error(
"The input array has ndim=" + std::to_string(a_array_nd) +
", but a 2-dimensional array is expected.");
}

const py::ssize_t *a_array_shape = a_array.get_shape_raw();

if (a_array_shape[0] != a_array_shape[1]) {
throw py::value_error("The input array must be square,"
" but got a shape of (" +
std::to_string(a_array_shape[0]) + ", " +
std::to_string(a_array_shape[1]) + ").");
}

bool is_a_array_c_contig = a_array.is_c_contiguous();
if (!is_a_array_c_contig) {
throw py::value_error("The input array "
"must be C-contiguous");
}

auto array_types = dpctl_td_ns::usm_ndarray_types();
int a_array_type_id =
array_types.typenum_to_lookup_id(a_array.get_typenum());

potrf_impl_fn_ptr_t potrf_fn = potrf_dispatch_vector[a_array_type_id];
if (potrf_fn == nullptr) {
throw py::value_error(
"No potrf implementation defined for the provided type "
"of the input matrix.");
}

char *a_array_data = a_array.get_data();
const std::int64_t n = a_array_shape[0];
const std::int64_t lda = std::max<size_t>(1UL, n);
const oneapi::mkl::uplo uplo_val =
static_cast<oneapi::mkl::uplo>(upper_lower);

std::vector<sycl::event> host_task_events;
sycl::event potrf_ev =
potrf_fn(q, uplo_val, n, a_array_data, lda, host_task_events, depends);

sycl::event args_ev =
dpctl::utils::keep_args_alive(q, {a_array}, host_task_events);

return std::make_pair(args_ev, potrf_ev);
}

template <typename fnT, typename T>
struct PotrfContigFactory
{
fnT get()
{
if constexpr (types::PotrfTypePairSupportFactory<T>::is_defined) {
return potrf_impl<T>;
}
else {
return nullptr;
}
}
};

void init_potrf_dispatch_vector(void)
{
dpctl_td_ns::DispatchVectorBuilder<potrf_impl_fn_ptr_t, PotrfContigFactory,
dpctl_td_ns::num_types>
contig;
contig.populate_dispatch_vector(potrf_dispatch_vector);
}
} // namespace lapack
} // namespace ext
} // namespace backend
} // namespace dpnp
61 changes: 61 additions & 0 deletions dpnp/backend/extensions/lapack/potrf.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
//*****************************************************************************
// Copyright (c) 2024, Intel Corporation
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are met:
// - Redistributions of source code must retain the above copyright notice,
// this list of conditions and the following disclaimer.
// - Redistributions in binary form must reproduce the above copyright notice,
// this list of conditions and the following disclaimer in the documentation
// and/or other materials provided with the distribution.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
// AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
// ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
// LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
// CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
// SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
// INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
// CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
// ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
// THE POSSIBILITY OF SUCH DAMAGE.
//*****************************************************************************

#pragma once

#include <CL/sycl.hpp>
#include <oneapi/mkl.hpp>

#include <dpctl4pybind11.hpp>

namespace dpnp
{
namespace backend
{
namespace ext
{
namespace lapack
{
extern std::pair<sycl::event, sycl::event>
potrf(sycl::queue exec_q,
dpctl::tensor::usm_ndarray a_array,
const std::int8_t upper_lower,
const std::vector<sycl::event> &depends = {});

extern std::pair<sycl::event, sycl::event>
potrf_batch(sycl::queue exec_q,
dpctl::tensor::usm_ndarray a_array,
const std::int8_t upper_lower,
const std::int64_t n,
const std::int64_t stride_a,
const std::int64_t batch_size,
const std::vector<sycl::event> &depends = {});

extern void init_potrf_dispatch_vector(void);
extern void init_potrf_batch_dispatch_vector(void);
} // namespace lapack
} // namespace ext
} // namespace backend
} // namespace dpnp
Loading