Skip to content

Commit

Permalink
Merge pull request #1122 from e10harvey/issue1121
Browse files Browse the repository at this point in the history
Fix BatchedDlbBuf build with cuda-9 and XL toolchains
  • Loading branch information
e10harvey authored Oct 2, 2021
2 parents 1274344 + dfcd8a1 commit ce5e984
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 16 deletions.
2 changes: 1 addition & 1 deletion example/gmres/gmres.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ Kokkos::Experimental::half_t abs(Kokkos::Experimental::half_t arg) {
return arg < 0.0 ? -arg : arg;
}

Kokkos::complex<Kokkos::Experimental::half_t> abs(Kokkos::complex<Kokkos::Experimental::half_t> arg) {
Kokkos::complex<Kokkos::Experimental::half_t> abs(Kokkos::complex<Kokkos::Experimental::half_t> arg) noexcept {
return Kokkos::complex<Kokkos::Experimental::half_t>(abs(Kokkos::complex<double>((double) arg.real(), (double) arg.imag())));
}
#endif // KOKKOS_HALF_T_IS_FLOAT
Expand Down
4 changes: 2 additions & 2 deletions src/batched/KokkosBatched_Util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,7 @@ namespace KokkosBatched {
struct Rank2 {};
};

/// \brief BoundsCheck class used to specific whether to check view bounds in
/// \brief BoundsCheck class used to specify whether to check view bounds in
/// BLAS/LAPACK DblBuf algorithms.
/// \var Yes Use functor with bounds check
/// \var No Use functor without bounds checks
Expand Down Expand Up @@ -668,7 +668,7 @@ namespace KokkosBatched {

template <class ViewType>
KOKKOS_INLINE_FUNCTION auto transpose_2d_view(ViewType v, const int *order) {
const int rank = 2;
constexpr int rank = 2;
const int dim[] = {v.extent_int(1), v.extent_int(0)};
using view_value_type = typename ViewType::value_type;
using execution_space_type = typename ViewType::execution_space;
Expand Down
1 change: 0 additions & 1 deletion src/batched/dense/KokkosBatched_Gemm_Decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,6 @@ int BatchedGemm(BatchedGemmHandleType *const handle, const ScalarType alpha,
? (c_m >= 16)
: (c_m >= 24 && c_m <= 32) || (c_m >= 45 && c_m <= 64))) {
handle->teamSz = handle->vecLen = 8;
// constexpr int tile_m = 32, tile_n = 32, tile_k = 8;
constexpr int tile_m = 32, tile_n = 32, tile_k = 8;
if (c_m % 32 == 0) // No bounds checking
ret =
Expand Down
10 changes: 6 additions & 4 deletions src/batched/dense/impl/KokkosBatched_Gemm_DblBuf_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,7 @@ class BatchedDblBufGemm {
constexpr int reg_n = TILE_N / TILE_K + 2 * !!(TILE_N % TILE_K);
constexpr int stride_m = TILE_K;
constexpr int stride_n = TILE_N / reg_n;
using functor_type =
__Functor<member_type, reg_m, reg_n, stride_m, stride_n>;
using functor_type = Functor<member_type, reg_m, reg_n, stride_m, stride_n>;

functor_type functor(*this, __A, __B, __C, TILE_M, TILE_N, TILE_K);

Expand Down Expand Up @@ -182,8 +181,11 @@ class BatchedDblBufGemm {
Kokkos::parallel_for("BatchedDblBufGemm", team_policy, functor);
}

public:
// Make Functor public for cuda 9.
// See https://github.com/kokkos/kokkos-kernels/issues/1121.
template <class MemberType, int REG_M, int REG_N, int STRIDE_M, int STRIDE_N>
class __Functor {
class Functor {
private:
BatchedDblBufGemm &__ei;
AViewType __A;
Expand All @@ -201,7 +203,7 @@ class BatchedDblBufGemm {
// below. If those are used, we get an invalid memory error from cuda. I
// suspect this is due to the values not being copied to device and then
// runtime resolution of the host address &__ei.
__Functor(BatchedDblBufGemm &ei, AViewType A, BViewType B, CViewType C,
Functor(BatchedDblBufGemm &ei, AViewType A, BViewType B, CViewType C,
unsigned tile_m = 1, unsigned tile_n = 1, unsigned tile_k = 1)
: __ei(ei),
__A(A),
Expand Down
8 changes: 0 additions & 8 deletions unit_test/batched/dense/Test_Batched_BatchedGemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,6 @@ void impl_test_batched_gemm_with_handle(BatchedGemmHandle* batchedGemmHandle,
}
}
}
// std::cout << "algo_type:" << algo_type << std::endl;
// std::cout << "C0:" << matCdim1 << ", C1:" << matCdim2 << std::endl;
// std::cout << "A0:" << matAdim1 << ", A1:" << matAdim2 << std::endl;
// std::cout << "B0:" << matBdim1 << ", B1:" << matBdim2 << std::endl;
EXPECT_NEAR_KK(diff / sum, 0, eps);
}

Expand Down Expand Up @@ -198,10 +194,6 @@ void impl_test_batched_gemm(const int N, const int matAdim1, const int matAdim2,
{
BatchedGemmHandle batchedGemmHandle(algo_type);

// batchedGemmHandle.enableDebug = true;
// std::cout << "Testing algo_type = " << algo_type << "/" <<
// GemmKokkosBatchedAlgos::N << std::endl;

ASSERT_EQ(batchedGemmHandle.get_kernel_algo_type(), algo_type);

if (algo_type == BaseKokkosBatchedAlgos::KK_SERIAL ||
Expand Down

0 comments on commit ce5e984

Please sign in to comment.