Skip to content

Commit

Permalink
[SYCL] Enable sub group masks for 64 bit subgroups (#7491)
Browse files Browse the repository at this point in the history
This patch is adding group ballot support for HIP (based on initial work
from @abagusetty on #6734 ), but also
extending the sub-group mask implementation to support 64 bit masks, as
a lot of AMD GPUs use 64 bit wavefronts.

Related to issue: #6718
  • Loading branch information
npmiller authored Dec 2, 2022
1 parent a578c81 commit 10d50ed
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 15 deletions.
1 change: 1 addition & 0 deletions libclc/amdgcn-amdhsa/libspirv/SOURCES
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@

workitem/get_global_offset.ll
group/group_ballot.cl
group/collectives.cl
group/collectives_helpers.ll
atomic/loadstore_helpers.ll
Expand Down
33 changes: 33 additions & 0 deletions libclc/amdgcn-amdhsa/libspirv/group/group_ballot.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include <spirv/spirv.h>
#include <spirv/spirv_types.h>

// from llvm/include/llvm/IR/InstrTypes.h
#define ICMP_NE 33

_CLC_DEF _CLC_CONVERGENT __clc_vec4_uint32_t
_Z29__spirv_GroupNonUniformBallotjb(unsigned flag, bool predicate) {
// only support subgroup for now
if (flag != Subgroup) {
__builtin_trap();
__builtin_unreachable();
}

// prepare result, we only support the ballot operation on 64 threads maximum
// so we only need the first two elements to represent the final mask
__clc_vec4_uint32_t res;
res[2] = 0;
res[3] = 0;

// run the ballot operation
res.xy = __builtin_amdgcn_uicmp((int)predicate, 0, ICMP_NE);

return res;
}
4 changes: 2 additions & 2 deletions sycl/include/sycl/detail/helpers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,8 @@ class Builder {
return group<Dims>(Global, Local, Global / Local, Index);
}

template <class ResType>
static ResType createSubGroupMask(uint32_t Bits, size_t BitsNum) {
template <class ResType, typename BitsType>
static ResType createSubGroupMask(BitsType Bits, size_t BitsNum) {
return ResType(Bits, BitsNum);
}

Expand Down
45 changes: 32 additions & 13 deletions sycl/include/sycl/ext/oneapi/sub_group_mask.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,19 @@ class Builder;
namespace ext {
namespace oneapi {

#if defined(__SYCL_DEVICE_ONLY__) && defined(__AMDGCN__) && \
(__AMDGCN_WAVEFRONT_SIZE == 64)
#define BITS_TYPE uint64_t
#else
#define BITS_TYPE uint32_t
#endif

struct sub_group_mask {
friend class detail::Builder;
static constexpr size_t max_bits = 32 /* implementation-defined */;
using BitsType = BITS_TYPE;

static constexpr size_t max_bits =
sizeof(BitsType) * CHAR_BIT /* implementation-defined */;
static constexpr size_t word_size = sizeof(uint32_t) * CHAR_BIT;

// enable reference to individual bit
Expand Down Expand Up @@ -55,9 +65,9 @@ struct sub_group_mask {

private:
// Reference to the word containing the bit
uint32_t &Ref;
BitsType &Ref;
// Bit mask where only referenced bit is set
uint32_t RefBit;
BitsType RefBit;
};

bool operator[](id<1> id) const {
Expand Down Expand Up @@ -96,9 +106,9 @@ struct sub_group_mask {
typename = sycl::detail::enable_if_t<std::is_integral<Type>::value>>
void insert_bits(Type bits, id<1> pos = 0) {
size_t insert_size = sizeof(Type) * CHAR_BIT;
uint32_t insert_data = (uint32_t)bits;
BitsType insert_data = (BitsType)bits;
insert_data <<= pos.get(0);
uint32_t mask = 0;
BitsType mask = 0;
if (pos.get(0) + insert_size < size())
mask |= (valuable_bits(bits_num) << (pos.get(0) + insert_size));
if (pos.get(0) < size() && pos.get(0))
Expand All @@ -108,8 +118,8 @@ struct sub_group_mask {
}

/* The bits are stored in the memory in the following way:
marray id | 0 | 1 | 2 | 3 |
bit id |7 .. 0|15 .. 8|23 .. 16|31 .. 24|
marray id | 0 | 1 | 2 | 3 |...
bit id |7 .. 0|15 .. 8|23 .. 16|31 .. 24|...
*/
template <typename Type, size_t Size,
typename = sycl::detail::enable_if_t<std::is_integral<Type>::value>>
Expand Down Expand Up @@ -158,7 +168,7 @@ struct sub_group_mask {

void set() { Bits = valuable_bits(bits_num); }
void set(id<1> id, bool value = true) { operator[](id) = value; }
void reset() { Bits = uint32_t{0}; }
void reset() { Bits = BitsType{0}; }
void reset(id<1> id) { operator[](id) = 0; }
void reset_low() { reset(find_low()); }
void reset_high() { reset(find_high()); }
Expand Down Expand Up @@ -240,13 +250,17 @@ struct sub_group_mask {
}

private:
sub_group_mask(uint32_t rhs, size_t bn) : Bits(rhs), bits_num(bn) {
sub_group_mask(BitsType rhs, size_t bn) : Bits(rhs), bits_num(bn) {
assert(bits_num <= max_bits);
}
inline uint32_t valuable_bits(size_t bn) const {
return static_cast<uint32_t>((1ULL << bn) - 1ULL);
inline BitsType valuable_bits(size_t bn) const {
assert(bn <= max_bits);
BitsType one = 1;
if (bn == max_bits)
return -one;
return (one << bn) - one;
}
uint32_t Bits;
BitsType Bits;
// Number of valuable bits
size_t bits_num;
};
Expand All @@ -259,15 +273,20 @@ group_ballot(Group g, bool predicate) {
#ifdef __SYCL_DEVICE_ONLY__
auto res = __spirv_GroupNonUniformBallot(
detail::spirv::group_scope<Group>::value, predicate);
BITS_TYPE val = res[0];
if constexpr (sizeof(BITS_TYPE) == 8)
val |= ((BITS_TYPE)res[1]) << 32;
return detail::Builder::createSubGroupMask<sub_group_mask>(
res[0], g.get_max_local_range()[0]);
val, g.get_max_local_range()[0]);
#else
(void)predicate;
throw exception{errc::feature_not_supported,
"Sub-group mask is not supported on host device"};
#endif
}

#undef BITS_TYPE

} // namespace oneapi
} // namespace ext
} // __SYCL_INLINE_VER_NAMESPACE(_V1)
Expand Down

0 comments on commit 10d50ed

Please sign in to comment.