Skip to content

Commit

Permalink
Merge pull request #1410 from kliegeois/gesv_1409
Browse files Browse the repository at this point in the history
Address #1409
  • Loading branch information
lucbv authored May 17, 2022
2 parents 3c46a88 + e137231 commit 97cc8e6
Show file tree
Hide file tree
Showing 3 changed files with 120 additions and 78 deletions.
35 changes: 17 additions & 18 deletions src/batched/dense/KokkosBatched_Gesv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,15 @@ struct Gesv {
/// using a batched LU decomposition, 2 batched triangular solves, and a batched
/// static pivoting.
///
/// \tparam MatrixType: Input type for the matrix, needs to be a 3D view
/// \tparam MatrixType: Input type for the matrix, needs to be a 2D view
/// \tparam VectorType: Input type for the right-hand side and the solution,
/// needs to be a 2D view
/// needs to be a 1D view
///
/// \param A [in]: batched matrix, a rank 3 view
/// \param X [out]: solution, a rank 2 view
/// \param B [in]: right-hand side, a rank 2 view
/// \param tmp [in]: a rank 3 view used to store temporary variable; dimension
/// must be N x n x (n+4) where N is the batched size and n is the number of
/// rows.
/// \param A [in]: matrix, a rank 2 view
/// \param X [out]: solution, a rank 1 view
/// \param B [in]: right-hand side, a rank 1 view
/// \param tmp [in]: a rank 2 view used to store temporary variable; dimension
/// must be n x (n+4) where n is the number of rows.
///
///
/// Two versions are available (those are chosen based on ArgAlgo):
Expand Down Expand Up @@ -103,14 +102,14 @@ struct SerialGesv {
/// using a batched LU decomposition, 2 batched triangular solves, and a batched
/// static pivoting.
///
/// \tparam MatrixType: Input type for the matrix, needs to be a 3D view
/// \tparam MatrixType: Input type for the matrix, needs to be a 2D view
/// \tparam VectorType: Input type for the right-hand side and the solution,
/// needs to be a 2D view
/// needs to be a 1D view
///
/// \param member [in]: TeamPolicy member
/// \param A [in]: batched matrix, a rank 3 view
/// \param X [out]: solution, a rank 2 view
/// \param B [in]: right-hand side, a rank 2 view
/// \param A [in]: matrix, a rank 2 view
/// \param X [out]: solution, a rank 1 view
/// \param B [in]: right-hand side, a rank 1 view
///
/// Two versions are available (those are chosen based on ArgAlgo):
///
Expand Down Expand Up @@ -141,14 +140,14 @@ struct TeamGesv {
/// using a batched LU decomposition, 2 batched triangular solves, and a batched
/// static pivoting.
///
/// \tparam MatrixType: Input type for the matrix, needs to be a 3D view
/// \tparam MatrixType: Input type for the matrix, needs to be a 2D view
/// \tparam VectorType: Input type for the right-hand side and the solution,
/// needs to be a 2D view
/// needs to be a 1D view
///
/// \param member [in]: TeamPolicy member
/// \param A [in]: batched matrix, a rank 3 view
/// \param X [out]: solution, a rank 2 view
/// \param B [in]: right-hand side, a rank 2 view
/// \param A [in]: matrix, a rank 2 view
/// \param X [out]: solution, a rank 1 view
/// \param B [in]: right-hand side, a rank 1 view
///
/// Two versions are available (those are chosen based on ArgAlgo):
///
Expand Down
3 changes: 3 additions & 0 deletions src/batched/dense/KokkosBatched_LU_Decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,7 @@ struct LU {

} // namespace KokkosBatched

#include "KokkosBatched_LU_Serial_Impl.hpp"
#include "KokkosBatched_LU_Team_Impl.hpp"

#endif
160 changes: 100 additions & 60 deletions src/batched/dense/impl/KokkosBatched_Gesv_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -446,16 +446,20 @@ struct SerialGesv<Gesv::StaticPivoting> {
return 1;
}

SerialLU<Algo::Level3::Unblocked>::invoke(PDAD);
int r_val = SerialLU<Algo::Level3::Unblocked>::invoke(PDAD);

SerialTrsm<Side::Left, Uplo::Lower, Trans::NoTranspose, Diag::Unit,
Algo::Level3::Unblocked>::invoke(1.0, PDAD, PDY);
if (r_val == 0)
r_val =
SerialTrsm<Side::Left, Uplo::Lower, Trans::NoTranspose, Diag::Unit,
Algo::Level3::Unblocked>::invoke(1.0, PDAD, PDY);

SerialTrsm<Side::Left, Uplo::Upper, Trans::NoTranspose, Diag::NonUnit,
Algo::Level3::Unblocked>::invoke(1.0, PDAD, PDY);
if (r_val == 0)
r_val =
SerialTrsm<Side::Left, Uplo::Upper, Trans::NoTranspose, Diag::NonUnit,
Algo::Level3::Unblocked>::invoke(1.0, PDAD, PDY);

SerialHadamard1D(PDY, D2, X);
return 0;
if (r_val == 0) SerialHadamard1D(PDY, D2, X);
return r_val;
}
};

Expand Down Expand Up @@ -489,16 +493,21 @@ struct SerialGesv<Gesv::NoPivoting> {
}
#endif

SerialLU<Algo::Level3::Unblocked>::invoke(A);
int r_val = SerialLU<Algo::Level3::Unblocked>::invoke(A);

SerialCopy<Trans::NoTranspose, 1>::invoke(Y, X);
SerialTrsm<Side::Left, Uplo::Lower, Trans::NoTranspose, Diag::Unit,
Algo::Level3::Unblocked>::invoke(1.0, A, X);
if (r_val == 0) r_val = SerialCopy<Trans::NoTranspose, 1>::invoke(Y, X);

SerialTrsm<Side::Left, Uplo::Upper, Trans::NoTranspose, Diag::NonUnit,
Algo::Level3::Unblocked>::invoke(1.0, A, X);
if (r_val == 0)
r_val =
SerialTrsm<Side::Left, Uplo::Lower, Trans::NoTranspose, Diag::Unit,
Algo::Level3::Unblocked>::invoke(1.0, A, X);

return 0;
if (r_val == 0)
r_val =
SerialTrsm<Side::Left, Uplo::Upper, Trans::NoTranspose, Diag::NonUnit,
Algo::Level3::Unblocked>::invoke(1.0, A, X);

return r_val;
}
};

Expand Down Expand Up @@ -557,22 +566,31 @@ struct TeamGesv<MemberType, Gesv::StaticPivoting> {
}
member.team_barrier();

TeamLU<MemberType, Algo::Level3::Unblocked>::invoke(member, PDAD);
int r_val =
TeamLU<MemberType, Algo::Level3::Unblocked>::invoke(member, PDAD);
member.team_barrier();

TeamTrsm<MemberType, Side::Left, Uplo::Lower, Trans::NoTranspose,
Diag::Unit, Algo::Level3::Unblocked>::invoke(member, 1.0, PDAD,
PDY);
member.team_barrier();
if (r_val == 0) {
r_val = TeamTrsm<MemberType, Side::Left, Uplo::Lower, Trans::NoTranspose,
Diag::Unit, Algo::Level3::Unblocked>::invoke(member, 1.0,
PDAD, PDY);
member.team_barrier();
}

TeamTrsm<MemberType, Side::Left, Uplo::Upper, Trans::NoTranspose,
Diag::NonUnit, Algo::Level3::Unblocked>::invoke(member, 1.0, PDAD,
PDY);
member.team_barrier();
if (r_val == 0) {
r_val =
TeamTrsm<MemberType, Side::Left, Uplo::Upper, Trans::NoTranspose,
Diag::NonUnit, Algo::Level3::Unblocked>::invoke(member, 1.0,
PDAD, PDY);
member.team_barrier();
}

TeamHadamard1D(member, PDY, D2, X);
member.team_barrier();
return 0;
if (r_val == 0) {
TeamHadamard1D(member, PDY, D2, X);
member.team_barrier();
}

return r_val;
}
};

Expand Down Expand Up @@ -605,21 +623,28 @@ struct TeamGesv<MemberType, Gesv::NoPivoting> {
}
#endif

TeamLU<MemberType, Algo::Level3::Unblocked>::invoke(member, A);
int r_val = TeamLU<MemberType, Algo::Level3::Unblocked>::invoke(member, A);
member.team_barrier();

TeamCopy<MemberType, Trans::NoTranspose, 1>::invoke(member, Y, X);
member.team_barrier();
if (r_val == 0) {
TeamCopy<MemberType, Trans::NoTranspose, 1>::invoke(member, Y, X);
member.team_barrier();
}

TeamTrsm<MemberType, Side::Left, Uplo::Lower, Trans::NoTranspose,
Diag::Unit, Algo::Level3::Unblocked>::invoke(member, 1.0, A, X);
member.team_barrier();
if (r_val == 0) {
TeamTrsm<MemberType, Side::Left, Uplo::Lower, Trans::NoTranspose,
Diag::Unit, Algo::Level3::Unblocked>::invoke(member, 1.0, A, X);
member.team_barrier();
}

TeamTrsm<MemberType, Side::Left, Uplo::Upper, Trans::NoTranspose,
Diag::NonUnit, Algo::Level3::Unblocked>::invoke(member, 1.0, A, X);
member.team_barrier();
if (r_val == 0) {
TeamTrsm<MemberType, Side::Left, Uplo::Upper, Trans::NoTranspose,
Diag::NonUnit, Algo::Level3::Unblocked>::invoke(member, 1.0, A,
X);
member.team_barrier();
}

return 0;
return r_val;
}
};

Expand Down Expand Up @@ -679,22 +704,31 @@ struct TeamVectorGesv<MemberType, Gesv::StaticPivoting> {

member.team_barrier();

TeamLU<MemberType, Algo::Level3::Unblocked>::invoke(member, PDAD);
int r_val =
TeamLU<MemberType, Algo::Level3::Unblocked>::invoke(member, PDAD);
member.team_barrier();

TeamVectorTrsm<MemberType, Side::Left, Uplo::Lower, Trans::NoTranspose,
Diag::Unit, Algo::Level3::Unblocked>::invoke(member, 1.0,
PDAD, PDY);
member.team_barrier();
if (r_val == 0) {
TeamVectorTrsm<MemberType, Side::Left, Uplo::Lower, Trans::NoTranspose,
Diag::Unit, Algo::Level3::Unblocked>::invoke(member, 1.0,
PDAD, PDY);
member.team_barrier();
}

TeamVectorTrsm<MemberType, Side::Left, Uplo::Upper, Trans::NoTranspose,
Diag::NonUnit, Algo::Level3::Unblocked>::invoke(member, 1.0,
PDAD, PDY);
member.team_barrier();
if (r_val == 0) {
TeamVectorTrsm<MemberType, Side::Left, Uplo::Upper, Trans::NoTranspose,
Diag::NonUnit, Algo::Level3::Unblocked>::invoke(member,
1.0, PDAD,
PDY);
member.team_barrier();
}

TeamVectorHadamard1D(member, PDY, D2, X);
member.team_barrier();
return 0;
if (r_val == 0) {
TeamVectorHadamard1D(member, PDY, D2, X);
member.team_barrier();
}

return r_val;
}
};

Expand Down Expand Up @@ -727,23 +761,29 @@ struct TeamVectorGesv<MemberType, Gesv::NoPivoting> {
}
#endif

TeamLU<MemberType, Algo::Level3::Unblocked>::invoke(member, A);
int r_val = TeamLU<MemberType, Algo::Level3::Unblocked>::invoke(member, A);
member.team_barrier();

TeamVectorCopy<MemberType, Trans::NoTranspose, 1>::invoke(member, Y, X);
member.team_barrier();
if (r_val == 0) {
TeamVectorCopy<MemberType, Trans::NoTranspose, 1>::invoke(member, Y, X);
member.team_barrier();
}

TeamVectorTrsm<MemberType, Side::Left, Uplo::Lower, Trans::NoTranspose,
Diag::Unit, Algo::Level3::Unblocked>::invoke(member, 1.0, A,
X);
member.team_barrier();
if (r_val == 0) {
TeamVectorTrsm<MemberType, Side::Left, Uplo::Lower, Trans::NoTranspose,
Diag::Unit, Algo::Level3::Unblocked>::invoke(member, 1.0,
A, X);
member.team_barrier();
}

TeamVectorTrsm<MemberType, Side::Left, Uplo::Upper, Trans::NoTranspose,
Diag::NonUnit, Algo::Level3::Unblocked>::invoke(member, 1.0,
A, X);
member.team_barrier();
if (r_val == 0) {
TeamVectorTrsm<MemberType, Side::Left, Uplo::Upper, Trans::NoTranspose,
Diag::NonUnit, Algo::Level3::Unblocked>::invoke(member,
1.0, A, X);
member.team_barrier();
}

return 0;
return r_val;
}
};

Expand Down

0 comments on commit 97cc8e6

Please sign in to comment.