Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Axpby bug fix (issue #2015) #2039

Merged
merged 14 commits on Nov 24, 2023
29 changes: 25 additions & 4 deletions blas/impl/KokkosBlas1_axpby_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,13 @@ struct Axpby_Functor {
} else if constexpr (scalar_y == 1) {
// Nothing to do: m_y(i) = m_y(i);
} else if constexpr (scalar_y == 2) {
m_y(i) = m_b(0) * m_y(i);
if (m_b(0) ==
Kokkos::ArithTraits<typename BV::non_const_value_type>::zero()) {
m_y(i) =
Kokkos::ArithTraits<typename YV::non_const_value_type>::zero();
} else {
m_y(i) = m_b(0) * m_y(i);
}
}
}
// **************************************************************
Expand All @@ -137,7 +143,12 @@ struct Axpby_Functor {
} else if constexpr (scalar_y == 1) {
m_y(i) = -m_x(i) + m_y(i);
} else if constexpr (scalar_y == 2) {
m_y(i) = -m_x(i) + m_b(0) * m_y(i);
if (m_b(0) ==
Kokkos::ArithTraits<typename BV::non_const_value_type>::zero()) {
m_y(i) = -m_x(i);
} else {
m_y(i) = -m_x(i) + m_b(0) * m_y(i);
}
}
}
// **************************************************************
Expand All @@ -151,7 +162,12 @@ struct Axpby_Functor {
} else if constexpr (scalar_y == 1) {
m_y(i) = m_x(i) + m_y(i);
} else if constexpr (scalar_y == 2) {
m_y(i) = m_x(i) + m_b(0) * m_y(i);
if (m_b(0) ==
Kokkos::ArithTraits<typename BV::non_const_value_type>::zero()) {
m_y(i) = m_x(i);
} else {
m_y(i) = m_x(i) + m_b(0) * m_y(i);
}
}
}
// **************************************************************
Expand All @@ -165,7 +181,12 @@ struct Axpby_Functor {
} else if constexpr (scalar_y == 1) {
m_y(i) = m_a(0) * m_x(i) + m_y(i);
} else if constexpr (scalar_y == 2) {
m_y(i) = m_a(0) * m_x(i) + m_b(0) * m_y(i);
if (m_b(0) ==
Kokkos::ArithTraits<typename BV::non_const_value_type>::zero()) {
m_y(i) = m_a(0) * m_x(i);
} else {
m_y(i) = m_a(0) * m_x(i) + m_b(0) * m_y(i);
}
}
}
}
Expand Down
157 changes: 137 additions & 20 deletions blas/impl/KokkosBlas1_axpby_mv_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,28 @@ struct Axpby_MV_Functor {
// Nothing to do: Y(i,j) := Y(i,j)
} else if constexpr (scalar_y == 2) {
if (m_b.extent(0) == 1) {
if (m_b(0) ==
Kokkos::ArithTraits<typename BV::non_const_value_type>::zero()) {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
#pragma ivdep
#endif
#ifdef KOKKOS_ENABLE_PRAGMA_VECTOR
#pragma vector always
#endif
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = m_b(0) * m_y(i, k);
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = Kokkos::ArithTraits<
typename YMV::non_const_value_type>::zero();
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
#pragma ivdep
#endif
#ifdef KOKKOS_ENABLE_PRAGMA_VECTOR
#pragma vector always
#endif
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = m_b(0) * m_y(i, k);
}
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
Expand Down Expand Up @@ -181,14 +195,27 @@ struct Axpby_MV_Functor {
}
} else if constexpr (scalar_y == 2) {
if (m_b.extent(0) == 1) {
if (m_b(0) ==
Kokkos::ArithTraits<typename BV::non_const_value_type>::zero()) {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
#pragma ivdep
#endif
#ifdef KOKKOS_ENABLE_PRAGMA_VECTOR
#pragma vector always
#endif
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = -m_x(i, k) + m_b(0) * m_y(i, k);
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = -m_x(i, k);
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
#pragma ivdep
#endif
#ifdef KOKKOS_ENABLE_PRAGMA_VECTOR
#pragma vector always
#endif
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = -m_x(i, k) + m_b(0) * m_y(i, k);
}
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
Expand Down Expand Up @@ -239,14 +266,27 @@ struct Axpby_MV_Functor {
}
} else if constexpr (scalar_y == 2) {
if (m_b.extent(0) == 1) {
if (m_b(0) ==
Kokkos::ArithTraits<typename BV::non_const_value_type>::zero()) {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
#pragma ivdep
#endif
#ifdef KOKKOS_ENABLE_PRAGMA_VECTOR
#pragma vector always
#endif
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = m_x(i, k) + m_b(0) * m_y(i, k);
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = m_x(i, k);
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
#pragma ivdep
#endif
#ifdef KOKKOS_ENABLE_PRAGMA_VECTOR
#pragma vector always
#endif
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = m_x(i, k) + m_b(0) * m_y(i, k);
}
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
Expand Down Expand Up @@ -334,14 +374,27 @@ struct Axpby_MV_Functor {
} else if constexpr (scalar_y == 2) {
if (m_a.extent(0) == 1) {
if (m_b.extent(0) == 1) {
if (m_b(0) == Kokkos::ArithTraits<
typename BV::non_const_value_type>::zero()) {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
#pragma ivdep
#endif
#ifdef KOKKOS_ENABLE_PRAGMA_VECTOR
#pragma vector always
#endif
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = m_a(0) * m_x(i, k) + m_b(0) * m_y(i, k);
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = m_a(0) * m_x(i, k);
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
#pragma ivdep
#endif
#ifdef KOKKOS_ENABLE_PRAGMA_VECTOR
#pragma vector always
#endif
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = m_a(0) * m_x(i, k) + m_b(0) * m_y(i, k);
}
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
Expand All @@ -356,14 +409,27 @@ struct Axpby_MV_Functor {
}
} else {
if (m_b.extent(0) == 1) {
if (m_b(0) == Kokkos::ArithTraits<
typename BV::non_const_value_type>::zero()) {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
#pragma ivdep
#endif
#ifdef KOKKOS_ENABLE_PRAGMA_VECTOR
#pragma vector always
#endif
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = m_a(k) * m_x(i, k) + m_b(0) * m_y(i, k);
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = m_a(k) * m_x(i, k);
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
#pragma ivdep
#endif
#ifdef KOKKOS_ENABLE_PRAGMA_VECTOR
#pragma vector always
#endif
for (size_type k = 0; k < numCols; ++k) {
m_y(i, k) = m_a(k) * m_x(i, k) + m_b(0) * m_y(i, k);
}
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_IVDEP
Expand Down Expand Up @@ -715,11 +781,22 @@ struct Axpby_MV_Unroll_Functor {
// Nothing to do: Y(i,j) := Y(i,j)
} else if constexpr (scalar_y == 2) {
if (m_b.extent(0) == 1) {
if (m_b(0) ==
Kokkos::ArithTraits<typename BV::non_const_value_type>::zero()) {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
#pragma unroll
#endif
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = m_b(0) * m_y(i, k);
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = Kokkos::ArithTraits<
typename YMV::non_const_value_type>::zero();
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
#pragma unroll
#endif
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = m_b(0) * m_y(i, k);
}
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
Expand Down Expand Up @@ -758,11 +835,21 @@ struct Axpby_MV_Unroll_Functor {
}
} else if constexpr (scalar_y == 2) {
if (m_b.extent(0) == 1) {
if (m_b(0) ==
Kokkos::ArithTraits<typename BV::non_const_value_type>::zero()) {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
#pragma unroll
#endif
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = -m_x(i, k) + m_b(0) * m_y(i, k);
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = -m_x(i, k);
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
#pragma unroll
#endif
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = -m_x(i, k) + m_b(0) * m_y(i, k);
}
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
Expand Down Expand Up @@ -801,11 +888,21 @@ struct Axpby_MV_Unroll_Functor {
}
} else if constexpr (scalar_y == 2) {
if (m_b.extent(0) == 1) {
if (m_b(0) ==
Kokkos::ArithTraits<typename BV::non_const_value_type>::zero()) {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
#pragma unroll
#endif
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = m_x(i, k) + m_b(0) * m_y(i, k);
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = m_x(i, k);
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
#pragma unroll
#endif
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = m_x(i, k) + m_b(0) * m_y(i, k);
}
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
Expand Down Expand Up @@ -872,11 +969,21 @@ struct Axpby_MV_Unroll_Functor {
} else if constexpr (scalar_y == 2) {
if (m_a.extent(0) == 1) {
if (m_b.extent(0) == 1) {
if (m_b(0) == Kokkos::ArithTraits<
typename BV::non_const_value_type>::zero()) {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
#pragma unroll
#endif
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = m_a(0) * m_x(i, k) + m_b(0) * m_y(i, k);
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = m_a(0) * m_x(i, k);
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
#pragma unroll
#endif
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = m_a(0) * m_x(i, k) + m_b(0) * m_y(i, k);
}
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
Expand All @@ -888,11 +995,21 @@ struct Axpby_MV_Unroll_Functor {
}
} else {
if (m_b.extent(0) == 1) {
if (m_b(0) == Kokkos::ArithTraits<
typename BV::non_const_value_type>::zero()) {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
#pragma unroll
#endif
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = m_a(k) * m_x(i, k) + m_b(0) * m_y(i, k);
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = m_a(k) * m_x(i, k);
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
#pragma unroll
#endif
for (int k = 0; k < UNROLL; ++k) {
m_y(i, k) = m_a(k) * m_x(i, k) + m_b(0) * m_y(i, k);
}
}
} else {
#ifdef KOKKOS_ENABLE_PRAGMA_UNROLL
Expand Down
Loading