diff --git a/libclc/amdgcn-amdhsa/libspirv/SOURCES b/libclc/amdgcn-amdhsa/libspirv/SOURCES index ae7071ab41bea..43a10d58dffe4 100644 --- a/libclc/amdgcn-amdhsa/libspirv/SOURCES +++ b/libclc/amdgcn-amdhsa/libspirv/SOURCES @@ -1,5 +1,6 @@ workitem/get_global_offset.ll +group/group_ballot.cl group/collectives.cl group/collectives_helpers.ll atomic/loadstore_helpers.ll diff --git a/libclc/amdgcn-amdhsa/libspirv/group/group_ballot.cl b/libclc/amdgcn-amdhsa/libspirv/group/group_ballot.cl new file mode 100644 index 0000000000000..52b1f170a8576 --- /dev/null +++ b/libclc/amdgcn-amdhsa/libspirv/group/group_ballot.cl @@ -0,0 +1,39 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include +#include + +// from llvm/include/llvm/IR/InstrTypes.h +#define ICMP_NE 33 + +// Subgroup ballot: result elements 0 and 1 hold the mask, 2 and 3 are zero. +// Any scope other than Subgroup traps. +_CLC_DEF _CLC_CONVERGENT __clc_vec4_uint32_t +_Z29__spirv_GroupNonUniformBallotjb(unsigned flag, bool predicate) { + // only support subgroup for now + if (flag != Subgroup) { + __builtin_trap(); + __builtin_unreachable(); + } + + // prepare result, we only support the ballot operation on 64 threads maximum + // so we only need the first two elements to represent the final mask + __clc_vec4_uint32_t res; + res[2] = 0; + res[3] = 0; + + // run the ballot operation; the builtin yields the whole mask as a single + // 64-bit scalar. Split it explicitly: assigning that scalar to res.xy + // would convert it to 32 bits and replicate the value into both elements. + unsigned long mask = __builtin_amdgcn_uicmp((int)predicate, 0, ICMP_NE); + res[0] = (unsigned)mask; + res[1] = (unsigned)(mask >> 32); + + return res; +} diff --git a/sycl/include/sycl/detail/helpers.hpp b/sycl/include/sycl/detail/helpers.hpp index 0301605892e02..8ab015674e499 100644 --- a/sycl/include/sycl/detail/helpers.hpp +++ b/sycl/include/sycl/detail/helpers.hpp @@ -82,8 +82,8 @@ class Builder { return group(Global, Local, Global / Local, Index); } - template - static ResType
createSubGroupMask(uint32_t Bits, size_t BitsNum) { + template + static ResType createSubGroupMask(BitsType Bits, size_t BitsNum) { return ResType(Bits, BitsNum); } diff --git a/sycl/include/sycl/ext/oneapi/sub_group_mask.hpp b/sycl/include/sycl/ext/oneapi/sub_group_mask.hpp index f02ea5158fcf8..4c2fbad0de4f7 100644 --- a/sycl/include/sycl/ext/oneapi/sub_group_mask.hpp +++ b/sycl/include/sycl/ext/oneapi/sub_group_mask.hpp @@ -23,9 +23,19 @@ class Builder; namespace ext { namespace oneapi { +#if defined(__SYCL_DEVICE_ONLY__) && defined(__AMDGCN__) && \ + (__AMDGCN_WAVEFRONT_SIZE == 64) +#define BITS_TYPE uint64_t +#else +#define BITS_TYPE uint32_t +#endif + struct sub_group_mask { friend class detail::Builder; - static constexpr size_t max_bits = 32 /* implementation-defined */; + using BitsType = BITS_TYPE; + + static constexpr size_t max_bits = + sizeof(BitsType) * CHAR_BIT /* implementation-defined */; static constexpr size_t word_size = sizeof(uint32_t) * CHAR_BIT; // enable reference to individual bit @@ -55,9 +65,9 @@ struct sub_group_mask { private: // Reference to the word containing the bit - uint32_t &Ref; + BitsType &Ref; // Bit mask where only referenced bit is set - uint32_t RefBit; + BitsType RefBit; }; bool operator[](id<1> id) const { @@ -96,9 +106,9 @@ struct sub_group_mask { typename = sycl::detail::enable_if_t::value>> void insert_bits(Type bits, id<1> pos = 0) { size_t insert_size = sizeof(Type) * CHAR_BIT; - uint32_t insert_data = (uint32_t)bits; + BitsType insert_data = (BitsType)bits; insert_data <<= pos.get(0); - uint32_t mask = 0; + BitsType mask = 0; if (pos.get(0) + insert_size < size()) mask |= (valuable_bits(bits_num) << (pos.get(0) + insert_size)); if (pos.get(0) < size() && pos.get(0)) @@ -108,8 +118,8 @@ struct sub_group_mask { } /* The bits are stored in the memory in the following way: - marray id | 0 | 1 | 2 | 3 | - bit id |7 .. 0|15 .. 8|23 .. 16|31 .. 24| + marray id | 0 | 1 | 2 | 3 |... + bit id |7 .. 0|15 .. 8|23 .. 16|31 .. 
24|... */ template ::value>> @@ -158,7 +168,7 @@ struct sub_group_mask { void set() { Bits = valuable_bits(bits_num); } void set(id<1> id, bool value = true) { operator[](id) = value; } - void reset() { Bits = uint32_t{0}; } + void reset() { Bits = BitsType{0}; } void reset(id<1> id) { operator[](id) = 0; } void reset_low() { reset(find_low()); } void reset_high() { reset(find_high()); } @@ -240,13 +250,17 @@ struct sub_group_mask { } private: - sub_group_mask(uint32_t rhs, size_t bn) : Bits(rhs), bits_num(bn) { + sub_group_mask(BitsType rhs, size_t bn) : Bits(rhs), bits_num(bn) { assert(bits_num <= max_bits); } - inline uint32_t valuable_bits(size_t bn) const { - return static_cast((1ULL << bn) - 1ULL); + inline BitsType valuable_bits(size_t bn) const { + assert(bn <= max_bits); + BitsType one = 1; + if (bn == max_bits) + return -one; + return (one << bn) - one; } - uint32_t Bits; + BitsType Bits; // Number of valuable bits size_t bits_num; }; @@ -259,8 +273,11 @@ group_ballot(Group g, bool predicate) { #ifdef __SYCL_DEVICE_ONLY__ auto res = __spirv_GroupNonUniformBallot( detail::spirv::group_scope::value, predicate); + BITS_TYPE val = res[0]; + if constexpr (sizeof(BITS_TYPE) == 8) + val |= ((BITS_TYPE)res[1]) << 32; return detail::Builder::createSubGroupMask( - res[0], g.get_max_local_range()[0]); + val, g.get_max_local_range()[0]); #else (void)predicate; throw exception{errc::feature_not_supported, @@ -268,6 +285,8 @@ group_ballot(Group g, bool predicate) { #endif } +#undef BITS_TYPE + } // namespace oneapi } // namespace ext } // __SYCL_INLINE_VER_NAMESPACE(_V1)