Skip to content

Commit

Permalink
Implement batched serial pttrf (kokkos#2256)
Browse files Browse the repository at this point in the history
* Batched serial pttrf implementation

* fix: use GEMM to add matrices

* fix: initialization order

* fformat

* fix: temporary variable in a test code

* fix: docstring of pttrf

* check_positive_definitiveness only if KOKKOSKERNELS_DEBUG_LEVEL > 0

* Improve the test for pttrf

* fix: int type

* fix: cleanup tests for SerialPttrf

* cleanup: remove unused deep_copies

* fix: docstrings and comments for pttrf

* ConjTranspose with conj and Transpose

* quick return in pttrf for size 1 or 0 matrix

* Add tests for invalid input

* fix: info computation

---------

Co-authored-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp>
  • Loading branch information
yasahi-hpc and Yuuichi Asahi committed Jul 10, 2024
1 parent ea430c3 commit 994891a
Show file tree
Hide file tree
Showing 9 changed files with 909 additions and 0 deletions.
73 changes: 73 additions & 0 deletions batched/dense/impl/KokkosBatched_Pttrf_Serial_Impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER
#ifndef KOKKOSBATCHED_PTTRF_SERIAL_IMPL_HPP_
#define KOKKOSBATCHED_PTTRF_SERIAL_IMPL_HPP_

#include <KokkosBatched_Util.hpp>
#include "KokkosBatched_Pttrf_Serial_Internal.hpp"

/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr)

namespace KokkosBatched {

/// \brief Validate the arguments handed to the serial pttrf kernel.
///
/// The view-ness and rank of both arguments are enforced at compile time.
/// The extents are only compared when KOKKOSKERNELS_DEBUG_LEVEL > 0; in that
/// case e must hold exactly one element fewer than d.
///
/// \param d [in]: view holding the diagonal entries (rank 1)
/// \param e [in]: view holding the off-diagonal entries (rank 1)
///
/// \return 0 if the arguments are accepted, 1 if the extents of d and e are
/// inconsistent (debug builds only).
template <typename DViewType, typename EViewType>
KOKKOS_INLINE_FUNCTION static int checkPttrfInput(
    [[maybe_unused]] const DViewType &d, [[maybe_unused]] const EViewType &e) {
  // Compile-time requirements: both arguments are Kokkos views ...
  static_assert(Kokkos::is_view<DViewType>::value,
                "KokkosBatched::pttrf: DViewType is not a Kokkos::View.");
  static_assert(Kokkos::is_view<EViewType>::value,
                "KokkosBatched::pttrf: EViewType is not a Kokkos::View.");

  // ... and both of them are one dimensional.
  static_assert(DViewType::rank == 1,
                "KokkosBatched::pttrf: DViewType must have rank 1.");
  static_assert(EViewType::rank == 1,
                "KokkosBatched::pttrf: EViewType must have rank 1.");

#if (KOKKOSKERNELS_DEBUG_LEVEL > 0)
  const int diag_len    = d.extent(0);
  const int offdiag_len = e.extent(0);

  // A tridiagonal matrix of order n has n - 1 off-diagonal entries.
  const bool extents_match = (diag_len == offdiag_len + 1);
  if (!extents_match) {
    Kokkos::printf(
        "KokkosBatched::pttrf: Dimensions of d and e do not match: d: %d, e: "
        "%d \n"
        "e.extent(0) must be equal to d.extent(0) - 1\n",
        diag_len, offdiag_len);
    return 1;
  }
#endif
  return 0;
}

/// \brief Unblocked serial pttrf: validates the views and forwards the raw
/// pointers and strides to SerialPttrfInternal.
///
/// \param d [inout]: n diagonal entries, overwritten by the factorization
/// \param e [inout]: n-1 off-diagonal entries, overwritten by the
/// factorization
///
/// \return 0 on success, a positive value if the input is rejected or the
/// matrix is found not to be positive definite.
template <>
struct SerialPttrf<Algo::Pttrf::Unblocked> {
  template <typename DViewType, typename EViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const DViewType &d,
                                           const EViewType &e) {
    // Quick return if possible: an empty matrix needs no work.
    if (d.extent(0) == 0) return 0;

    // A 1x1 matrix is positive definite only if its single entry is strictly
    // positive, so a zero entry must be reported as well (same criterion as
    // the d <= 0 test used by the internal kernel).
    if (d.extent(0) == 1) return (d(0) <= 0 ? 1 : 0);

    // Extent consistency between d and e (only checked in debug builds).
    const int info = checkPttrfInput(d, e);
    if (info) return info;

    return SerialPttrfInternal<Algo::Pttrf::Unblocked>::invoke(
        d.extent(0), d.data(), d.stride(0), e.data(), e.stride(0));
  }
};
} // namespace KokkosBatched

#endif // KOKKOSBATCHED_PTTRF_SERIAL_IMPL_HPP_
211 changes: 211 additions & 0 deletions batched/dense/impl/KokkosBatched_Pttrf_Serial_Internal.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER
#ifndef KOKKOSBATCHED_PTTRF_SERIAL_INTERNAL_HPP_
#define KOKKOSBATCHED_PTTRF_SERIAL_INTERNAL_HPP_

#include <KokkosBatched_Util.hpp>

/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr)

namespace KokkosBatched {

/// \brief Pointer/stride level kernel behind the serial pttrf.
///
/// \tparam AlgoType: algorithm tag; the definitions in this file specialize
/// the members for Algo::Pttrf::Unblocked.
///
/// Both overloads take the order of the matrix (n), a pointer to the diagonal
/// entries with its stride (d, ds0) and a pointer to the off-diagonal entries
/// with its stride (e, es0). d and e are modified in place. The return value
/// is 0 on success; a positive value reports a non-positive diagonal entry,
/// which is only checked when KOKKOSKERNELS_DEBUG_LEVEL > 0.
template <typename AlgoType>
struct SerialPttrfInternal {
// Real matrix: d and e share the same real value type.
template <typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(const int n,
ValueType *KOKKOS_RESTRICT d,
const int ds0,
ValueType *KOKKOS_RESTRICT e,
const int es0);

// Complex matrix: d is real while e is Kokkos::complex<ValueType>.
template <typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(
const int n, ValueType *KOKKOS_RESTRICT d, const int ds0,
Kokkos::complex<ValueType> *KOKKOS_RESTRICT e, const int es0);
};

///
/// Real matrix
///

/// \brief Unblocked L*D*L' factorization of a real symmetric positive
/// definite tridiagonal matrix, performed in place.
///
/// \param n   [in]: order of the matrix
/// \param d   [inout]: diagonal entries, element i is located at d[i * ds0]
/// \param ds0 [in]: stride between consecutive entries of d
/// \param e   [inout]: off-diagonal entries, element i is located at
/// e[i * es0]; each entry is replaced by e(i) / d(i)
/// \param es0 [in]: stride between consecutive entries of e
///
/// \return 0 on success. When KOKKOSKERNELS_DEBUG_LEVEL > 0, returns i + 1
/// (1-based index) for the first diagonal entry i found to be <= 0, and the
/// factorization stops at that point. No such check is made otherwise.
template <>
template <typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialPttrfInternal<Algo::Pttrf::Unblocked>::invoke(
    const int n, ValueType *KOKKOS_RESTRICT d, const int ds0,
    ValueType *KOKKOS_RESTRICT e, const int es0) {
  // Nothing to factorize; this also keeps the trailing d(n - 1) check from
  // reading before the start of d.
  if (n <= 0) return 0;

  // One elimination step: scale e(i) by the pivot d(i) and update d(i + 1).
  auto update = [&](const int i) {
    const auto ei_tmp = e[i * es0];
    e[i * es0]        = ei_tmp / d[i * ds0];
    d[(i + 1) * ds0] -= e[i * es0] * ei_tmp;
  };

  // Returns i + 1 if the i-th pivot is not strictly positive, 0 otherwise.
  // The pivot must be read through the stride of d. The test is compiled
  // out (always 0) unless KOKKOSKERNELS_DEBUG_LEVEL > 0.
  auto check_positive_definitiveness =
      [&]([[maybe_unused]] const int i) -> int {
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0)
    if (d[i * ds0] <= 0.0) return i + 1;
#endif
    return 0;
  };

  // Compute the L*D*L' (or U'*D*U) factorization of A.
  // The first (n - 1) % 4 steps are peeled off so that the main loop can be
  // unrolled by a factor of 4.
  const int i4 = (n - 1) % 4;
  for (int i = 0; i < i4; i++) {
    if (const int info = check_positive_definitiveness(i)) return info;
    update(i);
  }  // for (int i = 0; i < i4; i++)

  for (int i = i4; i < n - 4; i += 4) {
    if (const int info = check_positive_definitiveness(i)) return info;
    update(i);

    if (const int info = check_positive_definitiveness(i + 1)) return info;
    update(i + 1);

    if (const int info = check_positive_definitiveness(i + 2)) return info;
    update(i + 2);

    if (const int info = check_positive_definitiveness(i + 3)) return info;
    update(i + 3);
  }  // for (int i = i4; i < n - 4; i += 4)

  // The last pivot is never used as a divisor, but it still has to be
  // positive for the matrix to be positive definite.
  return check_positive_definitiveness(n - 1);
}

///
/// Complex matrix
///

/// \brief Unblocked L*D*L**H factorization of a complex Hermitian positive
/// definite tridiagonal matrix, performed in place.
///
/// \param n   [in]: order of the matrix
/// \param d   [inout]: real diagonal entries, element i is located at
/// d[i * ds0]
/// \param ds0 [in]: stride between consecutive entries of d
/// \param e   [inout]: complex off-diagonal entries, element i is located at
/// e[i * es0]; each entry is replaced by e(i) / d(i)
/// \param es0 [in]: stride between consecutive entries of e
///
/// \return 0 on success. When KOKKOSKERNELS_DEBUG_LEVEL > 0, returns i + 1
/// (1-based index) for the first diagonal entry i found to be <= 0, and the
/// factorization stops at that point. No such check is made otherwise.
template <>
template <typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialPttrfInternal<Algo::Pttrf::Unblocked>::invoke(
    const int n, ValueType *KOKKOS_RESTRICT d, const int ds0,
    Kokkos::complex<ValueType> *KOKKOS_RESTRICT e, const int es0) {
  // Nothing to factorize; this also keeps the trailing d(n - 1) check from
  // reading before the start of d.
  if (n <= 0) return 0;

  // One elimination step: scale the real and imaginary parts of e(i) by the
  // real pivot d(i), then subtract |e(i)|^2 / d(i) from d(i + 1).
  auto update = [&](const int i) {
    const auto eir_tmp = e[i * es0].real();
    const auto eii_tmp = e[i * es0].imag();
    const auto f_tmp   = eir_tmp / d[i * ds0];
    const auto g_tmp   = eii_tmp / d[i * ds0];
    e[i * es0]         = Kokkos::complex<ValueType>(f_tmp, g_tmp);
    d[(i + 1) * ds0] = d[(i + 1) * ds0] - f_tmp * eir_tmp - g_tmp * eii_tmp;
  };

  // Returns i + 1 if the i-th pivot is not strictly positive, 0 otherwise.
  // The pivot must be read through the stride of d. The test is compiled
  // out (always 0) unless KOKKOSKERNELS_DEBUG_LEVEL > 0.
  auto check_positive_definitiveness =
      [&]([[maybe_unused]] const int i) -> int {
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0)
    if (d[i * ds0] <= 0.0) return i + 1;
#endif
    return 0;
  };

  // Compute the L*D*L' (or U'*D*U) factorization of A.
  // The first (n - 1) % 4 steps are peeled off so that the main loop can be
  // unrolled by a factor of 4.
  const int i4 = (n - 1) % 4;
  for (int i = 0; i < i4; i++) {
    if (const int info = check_positive_definitiveness(i)) return info;
    update(i);
  }  // for (int i = 0; i < i4; i++)

  for (int i = i4; i < n - 4; i += 4) {
    if (const int info = check_positive_definitiveness(i)) return info;
    update(i);

    if (const int info = check_positive_definitiveness(i + 1)) return info;
    update(i + 1);

    if (const int info = check_positive_definitiveness(i + 2)) return info;
    update(i + 2);

    if (const int info = check_positive_definitiveness(i + 3)) return info;
    update(i + 3);
  }  // for (int i = i4; i < n - 4; i += 4)

  // The last pivot is never used as a divisor, but it still has to be
  // positive for the matrix to be positive definite.
  return check_positive_definitiveness(n - 1);
}

} // namespace KokkosBatched

#endif // KOKKOSBATCHED_PTTRF_SERIAL_INTERNAL_HPP_
52 changes: 52 additions & 0 deletions batched/dense/src/KokkosBatched_Pttrf.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER
#ifndef KOKKOSBATCHED_PTTRF_HPP_
#define KOKKOSBATCHED_PTTRF_HPP_

#include <KokkosBatched_Util.hpp>

/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr)

namespace KokkosBatched {

/// \brief Serial Batched Pttrf:
/// Compute the factorization L*D*L**T (or L*D*L**H) of a real symmetric
/// (or complex Hermitian) positive definite tridiagonal matrix A.
/// invoke() operates on a single matrix, described by two rank-1 views, and
/// overwrites them in place.
///
/// \tparam ArgAlgo: Algorithm tag selecting the implementation
/// (Algo::Pttrf::Unblocked is provided by KokkosBatched_Pttrf_Serial_Impl.hpp)
/// \tparam DViewType: Type of the view holding the diagonal of A, needs to be
/// a 1D view
/// \tparam EViewType: Type of the view holding the upper/lower off-diagonal
/// of A, needs to be a 1D view
///
/// \param d [inout]: n diagonal elements of A, modified by the factorization
/// \param e [inout]: n-1 upper/lower off-diagonal elements of A, modified by
/// the factorization
///
/// \return 0 on success; a positive value if the input is rejected or the
/// matrix is detected not to be positive definite. Part of these checks is
/// only compiled in when KOKKOSKERNELS_DEBUG_LEVEL > 0.
///
/// No nested parallel_for is used inside of the function.
///

template <typename ArgAlgo>
struct SerialPttrf {
template <typename DViewType, typename EViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const DViewType &d,
const EViewType &e);
};

} // namespace KokkosBatched

#include "KokkosBatched_Pttrf_Serial_Impl.hpp"

#endif // KOKKOSBATCHED_PTTRF_HPP_
3 changes: 3 additions & 0 deletions batched/dense/unit_test/Test_Batched_Dense.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@
#include "Test_Batched_SerialTrtri_Real.hpp"
#include "Test_Batched_SerialTrtri_Complex.hpp"
#include "Test_Batched_SerialSVD.hpp"
#include "Test_Batched_SerialPttrf.hpp"
#include "Test_Batched_SerialPttrf_Real.hpp"
#include "Test_Batched_SerialPttrf_Complex.hpp"

// Team Kernels
#include "Test_Batched_TeamAxpy.hpp"
Expand Down
40 changes: 40 additions & 0 deletions batched/dense/unit_test/Test_Batched_DenseUtils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,46 @@ void create_banded_triangular_matrix(InViewType& in, OutViewType& out,
}
Kokkos::deep_copy(out, h_out);
}

/// \brief Build batched matrices whose k-th diagonal holds the input vectors.
/// Every other entry of the output is set to zero. k > 0 selects a diagonal
/// above the main one, k < 0 a diagonal below it and k = 0 the main diagonal.
/// The data is assembled on the host and then copied to the output view.
///
/// \tparam InViewType: Input type for the vector, needs to be a 2D view
/// \tparam OutViewType: Output type for the matrix, needs to be a 3D view
///
/// \param in [in]: Input batched vector, a rank 2 view
/// \param out [out]: Output batched matrix, a rank 3 view whose diagonal
/// component selected by k receives the input vector
/// \param k [in]: The diagonal offset to be filled (default is 0).
///
template <typename InViewType, typename OutViewType>
void create_diagonal_matrix(InViewType& in, OutViewType& out, int k = 0) {
  auto host_src = Kokkos::create_mirror_view(in);
  auto host_dst = Kokkos::create_mirror_view(out);

  const int num_batches = in.extent(0);
  const int vec_len     = in.extent(1);

  assert(out.extent(0) == in.extent(0));
  assert(out.extent(1) == in.extent(1) + abs(k));

  // A non-negative k shifts the column index, a negative k the row index.
  int row_shift = 0;
  int col_shift = 0;
  if (k >= 0) {
    col_shift = k;
  } else {
    row_shift = -k;
  }

  // Start from an all-zero host matrix.
  using ScalarType = typename OutViewType::non_const_value_type;
  Kokkos::deep_copy(host_dst, ScalarType(0.0));

  // Bring the input to the host and scatter it onto the requested diagonal.
  Kokkos::deep_copy(host_src, in);
  for (int ib = 0; ib < num_batches; ++ib) {
    for (int j = 0; j < vec_len; ++j) {
      host_dst(ib, row_shift + j, col_shift + j) = host_src(ib, j);
    }
  }

  Kokkos::deep_copy(out, host_dst);
}

} // namespace KokkosBatched

#endif // TEST_BATCHED_DENSE_HELPER_HPP
Loading

0 comments on commit 994891a

Please sign in to comment.