Skip to content

Commit

Permalink
[SYCL][ESIMD] Add support for different types for lsc functions (#6952)
Browse files Browse the repository at this point in the history
  • Loading branch information
fineg74 authored Oct 6, 2022
1 parent 6b24fdc commit d9e40ec
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 35 deletions.
3 changes: 1 addition & 2 deletions llvm/lib/SYCLLowerIR/ESIMD/ESIMDVerifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,7 @@ static const char *LegalSYCLFunctions[] = {
"^sycl::_V1::ext::oneapi::sub_group::.+",
"^sycl::_V1::ext::oneapi::experimental::spec_constant<.+>::.+",
"^sycl::_V1::ext::oneapi::experimental::this_sub_group",
"^sycl::_V1::ext::oneapi::experimental::bfloat16::.+",
"^sycl::_V1::ext::oneapi::experimental::tfloat32::.+"};
"^sycl::_V1::ext::oneapi::experimental::bfloat16::.+"};

static const char *LegalSYCLFunctionsInStatelessMode[] = {
"^sycl::_V1::multi_ptr<.+>::get",
Expand Down
17 changes: 10 additions & 7 deletions sycl/include/sycl/ext/intel/experimental/esimd/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,17 +185,20 @@ constexpr lsc_data_size expand_data_size(lsc_data_size DS) {
}

/// Selects an integer type for \c T based on \c sizeof(T) and, in the second
/// definition below, on the signedness of \c T.
///
/// NOTE(review): this listing is a rendered diff without +/- markers, so two
/// definitions of \c type appear back to back. Presumably the first is the line
/// removed by the commit and the second is its replacement - confirm against
/// the repository before relying on either one.
template <typename T> struct lsc_expand_type {
// First definition: types narrower than 4 bytes map to uint32_t; any other
// type is passed through unchanged as T.
using type = typename std::conditional<sizeof(T) < 4, uint32_t, T>::type;
// Second definition: the result is always a 32- or 64-bit integer.
// sizeof(T) <= 4 selects the 32-bit pair, anything larger the 64-bit pair;
// std::is_signed picks the signed or unsigned member of that pair.
// std::is_signed is also true for floating-point types, so e.g. float maps to
// int32_t and double maps to int64_t.
using type = std::conditional_t<
sizeof(T) <= 4,
std::conditional_t<std::is_signed<T>::value, int32_t, uint32_t>,
std::conditional_t<std::is_signed<T>::value, int64_t, uint64_t>>;
};

/// Selects an unsigned integer type with the same size as \c T, falling back to
/// \c T itself for sizes that have no mapping.
///
/// NOTE(review): this listing is a rendered diff without +/- markers, so two
/// formulations appear one after the other. Presumably the private helper
/// aliases and the first \c type definition are the removed lines, and the
/// nested \c std::conditional_t definition is their replacement - confirm
/// against the repository.
template <typename T> struct lsc_bitcast_type {
private:
// _type1: uint16_t for 2-byte T, otherwise T.
using _type1 = typename std::conditional<sizeof(T) == 2, uint16_t, T>::type;
// _type2: uint8_t for 1-byte T, otherwise T.
using _type2 = typename std::conditional<sizeof(T) == 1, uint8_t, T>::type;

public:
// First definition: combines the helpers so that 1-byte T yields uint8_t,
// 2-byte T yields uint16_t, and every other T is left unchanged.
using type =
typename std::conditional<sizeof(_type2) == 1, _type2, _type1>::type;
// Second definition: a single size-keyed chain. 1, 2, 4 and 8 bytes map to
// uint8_t, uint16_t, uint32_t and uint64_t respectively; any other size
// leaves T unchanged.
using type = std::conditional_t<
sizeof(T) == 1, uint8_t,
std::conditional_t<
sizeof(T) == 2, uint16_t,
std::conditional_t<sizeof(T) == 4, uint32_t,
std::conditional_t<sizeof(T) == 8, uint64_t, T>>>>;
};

} // namespace detail
Expand Down
96 changes: 70 additions & 26 deletions sycl/include/sycl/ext/intel/experimental/esimd/memory.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,8 @@ __ESIMD_API std::enable_if_t<!std::is_pointer<AccessorTy>::value,
lsc_gather(AccessorTy acc, __ESIMD_NS::simd<uint32_t, N> offsets,
__ESIMD_NS::simd_mask<N> pred = 1) {
#ifdef __ESIMD_FORCE_STATELESS_MEM
return lsc_gather<T, N, DS, L1H>(acc.get_pointer().get(), offsets, pred);
return lsc_gather<T, NElts, DS, L1H, L3H>(acc.get_pointer().get(), offsets,
pred);
#else
detail::check_lsc_vector_size<NElts>();
detail::check_lsc_data_size<T, DS>();
Expand Down Expand Up @@ -478,11 +479,11 @@ lsc_gather(AccessorTy acc, __ESIMD_NS::simd<uint32_t, N> offsets,
/// given address, where S is a byte size of an "element" defined by the \c DS
/// template parameter. The maximum size of accessed block is 512 bytes for PVC
/// and 256 bytes for ACM (DG2).
/// When \? DS equals \? lsc_data_size::u64, the address must be 8-byte aligned,
/// When \c DS equals \c lsc_data_size::u64, the address must be 8-byte aligned,
/// otherwise - 4-bytes aligned. Allowed values for the data size are
/// \? lsc_data_size::u32 and \? lsc_data_size::u64. Allowed NElts values are
/// \c lsc_data_size::u32 and \c lsc_data_size::u64. Allowed NElts values are
/// 1, 2, 3, 4, 8, 16, 32, 64.
/// Note that to access 512 bytes, DS must be \? lsc_data_size::u64 and \c NElts
/// Note that to access 512 bytes, DS must be \c lsc_data_size::u64 and \c NElts
/// must be 64.
///
/// @tparam T is element type.
Expand Down Expand Up @@ -518,9 +519,19 @@ lsc_block_load(const T *p, __ESIMD_NS::simd_mask<1> pred = 1) {
constexpr detail::lsc_vector_size _VS =
detail::to_lsc_vector_size<NElts / SmallIntFactor>();
if constexpr (SmallIntFactor == 1) {
return __esimd_lsc_load_stateless<T, L1H, L3H, _AddressScale, _ImmOffset,
_DS, _VS, _Transposed, N>(pred.data(),
addrs.data());
if constexpr (_DS == lsc_data_size::u32) {
__ESIMD_NS::simd<uint32_t, NElts> result =
__esimd_lsc_load_stateless<uint32_t, L1H, L3H, _AddressScale,
_ImmOffset, lsc_data_size::u32, _VS,
_Transposed, N>(pred.data(), addrs.data());
return result.template bit_cast_view<T>();
} else {
__ESIMD_NS::simd<uint64_t, NElts> result =
__esimd_lsc_load_stateless<uint64_t, L1H, L3H, _AddressScale,
_ImmOffset, lsc_data_size::u64, _VS,
_Transposed, N>(pred.data(), addrs.data());
return result.template bit_cast_view<T>();
}
} else {
__ESIMD_NS::simd<uint32_t, NElts / SmallIntFactor> result =
__esimd_lsc_load_stateless<uint32_t, L1H, L3H, _AddressScale,
Expand Down Expand Up @@ -582,11 +593,20 @@ lsc_block_load(AccessorTy acc, uint32_t offset,
detail::to_lsc_vector_size<NElts / SmallIntFactor>();

if constexpr (SmallIntFactor == 1) {
return __esimd_lsc_load_bti<T, L1H, L3H, _AddressScale, _ImmOffset, _DS,
_VS, _Transposed, N>(pred.data(),
offsets.data(), si);
if constexpr (_DS == lsc_data_size::u32) {
__ESIMD_NS::simd<uint32_t, NElts> result =
__esimd_lsc_load_bti<uint32_t, L1H, L3H, _AddressScale, _ImmOffset,
lsc_data_size::u32, _VS, _Transposed, N>(
pred.data(), offsets.data(), si);
return result.template bit_cast_view<T>();
} else {
__ESIMD_NS::simd<uint64_t, NElts> result =
__esimd_lsc_load_bti<uint64_t, L1H, L3H, _AddressScale, _ImmOffset,
lsc_data_size::u64, _VS, _Transposed, N>(
pred.data(), offsets.data(), si);
return result.template bit_cast_view<T>();
}
} else {

__ESIMD_NS::simd<uint32_t, NElts / SmallIntFactor> result =
__esimd_lsc_load_bti<uint32_t, L1H, L3H, _AddressScale, _ImmOffset,
lsc_data_size::u32, _VS, _Transposed, N>(
Expand Down Expand Up @@ -904,8 +924,8 @@ lsc_scatter(AccessorTy acc, __ESIMD_NS::simd<uint32_t, N> offsets,
__ESIMD_NS::simd<T, N * NElts> vals,
__ESIMD_NS::simd_mask<N> pred = 1) {
#ifdef __ESIMD_FORCE_STATELESS_MEM
lsc_scatter<T, NElts, DS, L1H>(__ESIMD_DNS::accessorToPointer<T>(acc),
offsets, pred);
lsc_scatter<T, NElts, DS, L1H, L3H>(__ESIMD_DNS::accessorToPointer<T>(acc),
offsets, vals, pred);
#else
detail::check_lsc_vector_size<NElts>();
detail::check_lsc_data_size<T, DS>();
Expand Down Expand Up @@ -967,13 +987,23 @@ __ESIMD_API void lsc_block_store(T *p, __ESIMD_NS::simd<T, NElts> vals,
constexpr detail::lsc_vector_size _VS =
detail::to_lsc_vector_size<NElts / SmallIntFactor>();
if constexpr (SmallIntFactor == 1) {

__esimd_lsc_store_stateless<T, L1H, L3H, _AddressScale, _ImmOffset, _DS,
_VS, _Transposed, N>(pred.data(), addrs.data(),
vals.data());
if constexpr (_DS == lsc_data_size::u32) {
__esimd_lsc_store_stateless<uint32_t, L1H, L3H, _AddressScale, _ImmOffset,
_DS, _VS, _Transposed, N>(
pred.data(), addrs.data(),
sycl::bit_cast<__ESIMD_DNS::vector_type_t<uint32_t, NElts>>(
vals.data()));
} else {
__esimd_lsc_store_stateless<uint64_t, L1H, L3H, _AddressScale, _ImmOffset,
_DS, _VS, _Transposed, N>(
pred.data(), addrs.data(),
sycl::bit_cast<__ESIMD_DNS::vector_type_t<uint64_t, NElts>>(
vals.data()));
}
} else {
__ESIMD_NS::simd<uint32_t, NElts / SmallIntFactor> tmp =
vals.template bit_cast_view<uint32_t>();
__ESIMD_NS::simd<uint32_t, NElts / SmallIntFactor> tmp = sycl::bit_cast<
__ESIMD_DNS::vector_type_t<uint32_t, NElts / SmallIntFactor>>(
vals.data());

__esimd_lsc_store_stateless<uint32_t, L1H, L3H, _AddressScale, _ImmOffset,
lsc_data_size::u32, _VS, _Transposed, N>(
Expand Down Expand Up @@ -1010,7 +1040,7 @@ lsc_block_store(AccessorTy acc, uint32_t offset,
__ESIMD_NS::simd<T, NElts> vals,
__ESIMD_NS::simd_mask<1> pred = 1) {
#ifdef __ESIMD_FORCE_STATELESS_MEM
lsc_block_store<T, NElts, DS, L1H>(
lsc_block_store<T, NElts, DS, L1H, L3H>(
__ESIMD_DNS::accessorToPointer<T>(acc, offset), vals, pred);
#else
detail::check_lsc_data_size<T, DS>();
Expand All @@ -1033,15 +1063,29 @@ lsc_block_store(AccessorTy acc, uint32_t offset,
constexpr detail::lsc_vector_size _VS =
detail::to_lsc_vector_size<NElts / SmallIntFactor>();
if constexpr (SmallIntFactor > 1) {
__ESIMD_NS::simd<uint32_t, NElts / SmallIntFactor> Tmp =
vals.template bit_cast_view<uint32_t>();
__esimd_lsc_store_bti<uint32_t, L1H, L3H, _AddressScale, _ImmOffset,
lsc_data_size::u32, _VS, _Transposed, N>(
pred.data(), offsets.data(), Tmp.data(), si);
pred.data(), offsets.data(),
sycl::bit_cast<
__ESIMD_DNS::vector_type_t<uint32_t, NElts / SmallIntFactor>>(
vals.data()),
si);
} else {
__esimd_lsc_store_bti<T, L1H, L3H, _AddressScale, _ImmOffset, _DS, _VS,
_Transposed, N>(pred.data(), offsets.data(),
vals.data(), si);
if constexpr (_DS == lsc_data_size::u32) {
__esimd_lsc_store_bti<uint32_t, L1H, L3H, _AddressScale, _ImmOffset, _DS,
_VS, _Transposed, N>(
pred.data(), offsets.data(),
sycl::bit_cast<__ESIMD_DNS::vector_type_t<uint32_t, NElts>>(
vals.data()),
si);
} else {
__esimd_lsc_store_bti<uint64_t, L1H, L3H, _AddressScale, _ImmOffset, _DS,
_VS, _Transposed, N>(
pred.data(), offsets.data(),
sycl::bit_cast<__ESIMD_DNS::vector_type_t<uint64_t, NElts>>(
vals.data()),
si);
}
}
#endif
}
Expand Down

0 comments on commit d9e40ec

Please sign in to comment.