Skip to content

Commit

Permalink
[SYCL][CUDA] Allow joint_matrix to be loaded from const T (#6532)
Browse files Browse the repository at this point in the history
Fixes a bug where if `joint_matrix_load` attempts to load `joint_matrix`
from an array of `const T`incorrect behaviour will occur or an error
will be thrown. To fix this we make use of `std::remove_const_t<T>` in
appropriate places. This is important functionality for integrating
joint_matrix with existing SYCL-DNN routines.
I think that similar problems might occur in the intel backends for
their existing impl: I have not made corresponding changes because I do
not have the hardware to test it.

Signed-off-by: JackAKirk <jack.kirk@codeplay.com>
  • Loading branch information
JackAKirk authored Oct 5, 2022
1 parent 813ca36 commit 134618f
Showing 1 changed file with 55 additions and 53 deletions.
108 changes: 55 additions & 53 deletions sycl/include/sycl/ext/oneapi/matrix/matrix-tensorcore.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,10 +219,11 @@ struct joint_matrix_load_impl<
void load(sycl::ext::oneapi::experimental::matrix::joint_matrix<
S, Use, NumRows, NumCols, Layout, sycl::sub_group> &res,
multi_ptr<T, Space> src, size_t stride) {
if constexpr (std::is_same<T, uint16_t>::value ||
if constexpr (std::is_same<std::remove_const_t<T>, uint16_t>::value ||
std::is_same<
T, sycl::ext::oneapi::experimental::bfloat16>::value) {
auto tileptr = reinterpret_cast<int32_t const *>(src.get());
std::remove_const_t<T>,
sycl::ext::oneapi::experimental::bfloat16>::value) {
auto tileptr = reinterpret_cast<const int32_t *>(src.get());
auto destptr = reinterpret_cast<int32_t *>(&res.wi_marray);
if constexpr (NumRows == 16 && NumCols == 16) {
if constexpr (Use ==
Expand All @@ -247,8 +248,8 @@ struct joint_matrix_load_impl<
__mma_bf16_m32n8k16_ld_b(destptr, tileptr, stride,
get_layout_id<Layout>());
}
} else if constexpr (std::is_same<T, uint8_t>::value) {
auto tileptr = reinterpret_cast<int32_t const *>(src.get());
} else if constexpr (std::is_same<std::remove_const_t<T>, uint8_t>::value) {
auto tileptr = reinterpret_cast<const int32_t *>(src.get());
auto destptr = reinterpret_cast<int32_t *>(&res.wi_marray);
if constexpr (NumRows == 16 && NumCols == 16) {
if constexpr (Use ==
Expand All @@ -273,8 +274,8 @@ struct joint_matrix_load_impl<
__imma_m32n8k16_ld_b_u8(destptr, tileptr, stride,
get_layout_id<Layout>());
}
} else if constexpr (std::is_same<T, int8_t>::value) {
auto tileptr = reinterpret_cast<int32_t const *>(src.get());
} else if constexpr (std::is_same<std::remove_const_t<T>, int8_t>::value) {
auto tileptr = reinterpret_cast<const int32_t *>(src.get());
auto destptr = reinterpret_cast<int32_t *>(&res.wi_marray);
if constexpr (NumRows == 16 && NumCols == 16) {
if constexpr (Use ==
Expand All @@ -299,8 +300,8 @@ struct joint_matrix_load_impl<
__imma_m32n8k16_ld_b_s8(destptr, tileptr, stride,
get_layout_id<Layout>());
}
} else if constexpr (std::is_same<T, half>::value) {
auto tileptr = reinterpret_cast<int32_t const *>(src.get());
} else if constexpr (std::is_same<std::remove_const_t<T>, half>::value) {
auto tileptr = reinterpret_cast<const int32_t *>(src.get());
auto dstptr = reinterpret_cast<int32_t *>(&res.wi_marray);
if constexpr (NumRows == 16 && NumCols == 16) {
if constexpr (Use ==
Expand Down Expand Up @@ -332,7 +333,7 @@ struct joint_matrix_load_impl<
get_layout_id<Layout>());
}

} else if constexpr (std::is_same<T, int32_t>::value) {
} else if constexpr (std::is_same<std::remove_const_t<T>, int32_t>::value) {
auto destptr = reinterpret_cast<int32_t *>(&res.wi_marray);
if constexpr (NumRows == 16 && NumCols == 16) {
__imma_m16n16k16_ld_c(destptr, src.get(), stride,
Expand All @@ -344,7 +345,7 @@ struct joint_matrix_load_impl<
__imma_m32n8k16_ld_c(destptr, src.get(), stride,
get_layout_id<Layout>());
}
} else if constexpr (std::is_same<T, float>::value) {
} else if constexpr (std::is_same<std::remove_const_t<T>, float>::value) {
if constexpr (std::is_same<S, float>::value) {
auto dstptr = reinterpret_cast<float *>(&res.wi_marray);
if constexpr (NumRows == 16 && NumCols == 16) {
Expand All @@ -360,7 +361,7 @@ struct joint_matrix_load_impl<
} else if constexpr (std::is_same<S,
sycl::ext::oneapi::experimental::
matrix::precision::tf32>::value) {
auto tileptr = reinterpret_cast<int32_t *>(src.get());
auto tileptr = reinterpret_cast<const int32_t *>(src.get());
auto dstptr = reinterpret_cast<int32_t *>(&res.wi_marray);
if constexpr (NumRows == 16 && NumCols == 8) {
__mma_tf32_m16n16k8_ld_a(dstptr, tileptr, stride,
Expand All @@ -370,7 +371,7 @@ struct joint_matrix_load_impl<
get_layout_id<Layout>());
}
}
} else if constexpr (std::is_same<T, double>::value) {
} else if constexpr (std::is_same<std::remove_const_t<T>, double>::value) {
auto dstptr = reinterpret_cast<double *>(&res.wi_marray);
if constexpr (Use ==
sycl::ext::oneapi::experimental::matrix::matrix_use::a) {
Expand Down Expand Up @@ -560,9 +561,9 @@ struct joint_matrix_mad_impl<
D;
if constexpr (M == 16 && N == 16 && K == 16) {
if constexpr (std::is_same<T2, int32_t>::value) {
auto ptrA = reinterpret_cast<int32_t const *>(&A.wi_marray);
auto ptrB = reinterpret_cast<int32_t const *>(&B.wi_marray);
auto ptrC = reinterpret_cast<int32_t const *>(&C.wi_marray);
auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
auto ptrC = reinterpret_cast<const int32_t *>(&C.wi_marray);
auto ptrD = reinterpret_cast<int32_t *>(&D.wi_marray);
if constexpr (std::is_same<T1, int8_t>::value) {
__imma_m16n16k16_mma_s8(ptrD, ptrA, ptrB, ptrC,
Expand All @@ -572,34 +573,34 @@ struct joint_matrix_mad_impl<
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
} else if constexpr (std::is_same<T1, half>::value) {
auto ptrA = reinterpret_cast<int32_t const *>(&A.wi_marray);
auto ptrB = reinterpret_cast<int32_t const *>(&B.wi_marray);
auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
if constexpr (std::is_same<T2, float>::value) {
__hmma_m16n16k16_mma_f32f32(
reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<float const *>(&C.wi_marray),
reinterpret_cast<const float *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
} else if constexpr (std::is_same<T2, half>::value) {
__hmma_m16n16k16_mma_f16f16(
reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<int32_t const *>(&C.wi_marray),
reinterpret_cast<const int32_t *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
} else if constexpr (std::is_same<T1, uint16_t>::value ||
std::is_same<T1, sycl::ext::oneapi::experimental::
bfloat16>::value) {
__mma_bf16_m16n16k16_mma_f32(
reinterpret_cast<float *>(&D.wi_marray),
reinterpret_cast<int32_t const *>(&A.wi_marray),
reinterpret_cast<int32_t const *>(&B.wi_marray),
reinterpret_cast<float const *>(&C.wi_marray),
reinterpret_cast<const int32_t *>(&A.wi_marray),
reinterpret_cast<const int32_t *>(&B.wi_marray),
reinterpret_cast<const float *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
} else if constexpr (M == 8 && N == 32 && K == 16) {
if constexpr (std::is_same<T2, int32_t>::value) {
auto ptrA = reinterpret_cast<int32_t const *>(&A.wi_marray);
auto ptrB = reinterpret_cast<int32_t const *>(&B.wi_marray);
auto ptrC = reinterpret_cast<int32_t const *>(&C.wi_marray);
auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
auto ptrC = reinterpret_cast<const int32_t *>(&C.wi_marray);
auto ptrD = reinterpret_cast<int32_t *>(&D.wi_marray);
if constexpr (std::is_same<T1, int8_t>::value) {
__imma_m8n32k16_mma_s8(ptrD, ptrA, ptrB, ptrC,
Expand All @@ -609,34 +610,34 @@ struct joint_matrix_mad_impl<
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
} else if constexpr (std::is_same<T1, half>::value) {
auto ptrA = reinterpret_cast<int32_t const *>(&A.wi_marray);
auto ptrB = reinterpret_cast<int32_t const *>(&B.wi_marray);
auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
if constexpr (std::is_same<T2, float>::value) {
__hmma_m8n32k16_mma_f32f32(
reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<float const *>(&C.wi_marray),
reinterpret_cast<const float *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
} else if constexpr (std::is_same<T2, half>::value) {
__hmma_m8n32k16_mma_f16f16(
reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<int32_t const *>(&C.wi_marray),
reinterpret_cast<const int32_t *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
} else if constexpr (std::is_same<T1, uint16_t>::value ||
std::is_same<T1, sycl::ext::oneapi::experimental::
bfloat16>::value) {
__mma_bf16_m8n32k16_mma_f32(
reinterpret_cast<float *>(&D.wi_marray),
reinterpret_cast<int32_t const *>(&A.wi_marray),
reinterpret_cast<int32_t const *>(&B.wi_marray),
reinterpret_cast<float const *>(&C.wi_marray),
reinterpret_cast<const int32_t *>(&A.wi_marray),
reinterpret_cast<const int32_t *>(&B.wi_marray),
reinterpret_cast<const float *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
} else if constexpr (M == 32 && N == 8 && K == 16) {
if constexpr (std::is_same<T2, int32_t>::value) {
auto ptrA = reinterpret_cast<int32_t const *>(&A.wi_marray);
auto ptrB = reinterpret_cast<int32_t const *>(&B.wi_marray);
auto ptrC = reinterpret_cast<int32_t const *>(&C.wi_marray);
auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
auto ptrC = reinterpret_cast<const int32_t *>(&C.wi_marray);
auto ptrD = reinterpret_cast<int32_t *>(&D.wi_marray);
if constexpr (std::is_same<T1, int8_t>::value) {
__imma_m32n8k16_mma_s8(ptrD, ptrA, ptrB, ptrC,
Expand All @@ -650,22 +651,22 @@ struct joint_matrix_mad_impl<
bfloat16>::value) {
__mma_bf16_m32n8k16_mma_f32(
reinterpret_cast<float *>(&D.wi_marray),
reinterpret_cast<int32_t const *>(&A.wi_marray),
reinterpret_cast<int32_t const *>(&B.wi_marray),
reinterpret_cast<float const *>(&C.wi_marray),
reinterpret_cast<const int32_t *>(&A.wi_marray),
reinterpret_cast<const int32_t *>(&B.wi_marray),
reinterpret_cast<const float *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
} else if constexpr (std::is_same<T1, half>::value) {
auto ptrA = reinterpret_cast<int32_t const *>(&A.wi_marray);
auto ptrB = reinterpret_cast<int32_t const *>(&B.wi_marray);
auto ptrA = reinterpret_cast<const int32_t *>(&A.wi_marray);
auto ptrB = reinterpret_cast<const int32_t *>(&B.wi_marray);
if constexpr (std::is_same<T2, float>::value) {
__hmma_m32n8k16_mma_f32f32(
reinterpret_cast<float *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<float const *>(&C.wi_marray),
reinterpret_cast<const float *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
} else if constexpr (std::is_same<T2, half>::value) {
__hmma_m32n8k16_mma_f16f16(
reinterpret_cast<int32_t *>(&D.wi_marray), ptrA, ptrB,
reinterpret_cast<int32_t const *>(&C.wi_marray),
reinterpret_cast<const int32_t *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
}
Expand All @@ -677,9 +678,9 @@ struct joint_matrix_mad_impl<
get_layout_pair_id<LayoutA, LayoutB>(), 0);
} else if constexpr (std::is_same<T1, double>::value) {
__dmma_m8n8k4_mma_f64(reinterpret_cast<double *>(&D.wi_marray),
reinterpret_cast<double const *>(&A.wi_marray),
reinterpret_cast<double const *>(&B.wi_marray),
reinterpret_cast<double const *>(&C.wi_marray),
reinterpret_cast<const double *>(&A.wi_marray),
reinterpret_cast<const double *>(&B.wi_marray),
reinterpret_cast<const double *>(&C.wi_marray),
get_layout_pair_id<LayoutA, LayoutB>(), 0);
}
return D;
Expand All @@ -692,13 +693,14 @@ struct joint_matrix_mad_impl<
namespace experimental {
namespace matrix {

template <typename Group, typename S, typename T, matrix_use Use,
size_t NumRows, size_t NumCols, matrix_layout Layout,
access::address_space Space,
std::enable_if_t<std::is_same<S, T>::value ||
(std::is_same<S, precision::tf32>::value &&
std::is_same<T, float>::value),
bool> = true>
template <
typename Group, typename S, typename T, matrix_use Use, size_t NumRows,
size_t NumCols, matrix_layout Layout, access::address_space Space,
std::enable_if_t<std::is_same<S, std::remove_const_t<T>>::value ||
(std::is_same<S, precision::tf32>::value &&

std::is_same<std::remove_const_t<T>, float>::value),
bool> = true>
void joint_matrix_load(
Group sg, joint_matrix<S, Use, NumRows, NumCols, Layout, Group> &res,
multi_ptr<T, Space> src, size_t stride) {
Expand Down

0 comments on commit 134618f

Please sign in to comment.