Skip to content

Commit

Permalink
[SYCL] group operations update to use size() (#7589)
Browse files Browse the repository at this point in the history
`get_size()` returns byte size. it is `size()` that is wanted in these
group operations.
  • Loading branch information
cperkinsintel authored Dec 2, 2022
1 parent 65b7501 commit 0fa7542
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 13 deletions.
8 changes: 4 additions & 4 deletions sycl/include/sycl/detail/spirv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,7 @@ EnableIfNativeShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
template <typename T>
EnableIfVectorShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
T result;
for (int s = 0; s < x.get_size(); ++s) {
for (int s = 0; s < x.size(); ++s) {
result[s] = SubgroupShuffle(x[s], local_id);
}
return result;
Expand All @@ -578,7 +578,7 @@ EnableIfVectorShuffle<T> SubgroupShuffle(T x, id<1> local_id) {
template <typename T>
EnableIfVectorShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
T result;
for (int s = 0; s < x.get_size(); ++s) {
for (int s = 0; s < x.size(); ++s) {
result[s] = SubgroupShuffleXor(x[s], local_id);
}
return result;
Expand All @@ -587,7 +587,7 @@ EnableIfVectorShuffle<T> SubgroupShuffleXor(T x, id<1> local_id) {
template <typename T>
EnableIfVectorShuffle<T> SubgroupShuffleDown(T x, uint32_t delta) {
T result;
for (int s = 0; s < x.get_size(); ++s) {
for (int s = 0; s < x.size(); ++s) {
result[s] = SubgroupShuffleDown(x[s], delta);
}
return result;
Expand All @@ -596,7 +596,7 @@ EnableIfVectorShuffle<T> SubgroupShuffleDown(T x, uint32_t delta) {
template <typename T>
EnableIfVectorShuffle<T> SubgroupShuffleUp(T x, uint32_t delta) {
T result;
for (int s = 0; s < x.get_size(); ++s) {
for (int s = 0; s < x.size(); ++s) {
result[s] = SubgroupShuffleUp(x[s], delta);
}
return result;
Expand Down
6 changes: 3 additions & 3 deletions sycl/include/sycl/ext/oneapi/group_algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ detail::enable_if_t<(detail::is_generic_group<Group>::value &&
typename Group::id_type local_id) {
#ifdef __SYCL_DEVICE_ONLY__
T result;
for (int s = 0; s < x.get_size(); ++s) {
for (int s = 0; s < x.size(); ++s) {
result[s] = broadcast(g, x[s], local_id);
}
return result;
Expand Down Expand Up @@ -212,7 +212,7 @@ detail::enable_if_t<(detail::is_generic_group<Group>::value &&
linear_local_id) {
#ifdef __SYCL_DEVICE_ONLY__
T result;
for (int s = 0; s < x.get_size(); ++s) {
for (int s = 0; s < x.size(); ++s) {
result[s] = broadcast(g, x[s], linear_local_id);
}
return result;
Expand Down Expand Up @@ -250,7 +250,7 @@ detail::enable_if_t<(detail::is_generic_group<Group>::value &&
T> broadcast(Group g, T x) {
#ifdef __SYCL_DEVICE_ONLY__
T result;
for (int s = 0; s < x.get_size(); ++s) {
for (int s = 0; s < x.size(); ++s) {
result[s] = broadcast(g, x[s]);
}
return result;
Expand Down
12 changes: 6 additions & 6 deletions sycl/include/sycl/group_algorithm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ reduce_over_group(Group g, T x, BinaryOperation binary_op) {
std::is_same<decltype(binary_op(x[0], x[0])), float>::value),
"Result type of binary_op must match reduction accumulation type.");
T result;
for (int s = 0; s < x.get_size(); ++s) {
for (int s = 0; s < x.size(); ++s) {
result[s] = reduce_over_group(g, x[s], binary_op);
}
return result;
Expand Down Expand Up @@ -280,7 +280,7 @@ reduce_over_group(Group g, V x, T init, BinaryOperation binary_op) {
"Result type of binary_op must match reduction accumulation type.");
#ifdef __SYCL_DEVICE_ONLY__
T result = init;
for (int s = 0; s < x.get_size(); ++s) {
for (int s = 0; s < x.size(); ++s) {
result[s] = binary_op(init[s], reduce_over_group(g, x[s], binary_op));
}
return result;
Expand Down Expand Up @@ -656,7 +656,7 @@ exclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {
std::is_same<decltype(binary_op(x[0], x[0])), float>::value),
"Result type of binary_op must match scan accumulation type.");
T result;
for (int s = 0; s < x.get_size(); ++s) {
for (int s = 0; s < x.size(); ++s) {
result[s] = exclusive_scan_over_group(g, x[s], binary_op);
}
return result;
Expand All @@ -680,7 +680,7 @@ exclusive_scan_over_group(Group g, V x, T init, BinaryOperation binary_op) {
std::is_same<decltype(binary_op(init[0], x[0])), float>::value),
"Result type of binary_op must match scan accumulation type.");
T result;
for (int s = 0; s < x.get_size(); ++s) {
for (int s = 0; s < x.size(); ++s) {
result[s] = exclusive_scan_over_group(g, x[s], init[s], binary_op);
}
return result;
Expand Down Expand Up @@ -823,7 +823,7 @@ inclusive_scan_over_group(Group g, T x, BinaryOperation binary_op) {
std::is_same<decltype(binary_op(x[0], x[0])), float>::value),
"Result type of binary_op must match scan accumulation type.");
T result;
for (int s = 0; s < x.get_size(); ++s) {
for (int s = 0; s < x.size(); ++s) {
result[s] = inclusive_scan_over_group(g, x[s], binary_op);
}
return result;
Expand Down Expand Up @@ -917,7 +917,7 @@ inclusive_scan_over_group(Group g, V x, BinaryOperation binary_op, T init) {
std::is_same<decltype(binary_op(init[0], x[0])), float>::value),
"Result type of binary_op must match scan accumulation type.");
T result;
for (int s = 0; s < x.get_size(); ++s) {
for (int s = 0; s < x.size(); ++s) {
result[s] = inclusive_scan_over_group(g, x[s], binary_op, init[s]);
}
return result;
Expand Down

0 comments on commit 0fa7542

Please sign in to comment.