-
Notifications
You must be signed in to change notification settings - Fork 739
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SYCL][Reduction] Fix range deductions in reduction parallel_for (#9813)
Currently, the `parallel_for` overloads taking reduction variables take the dimensionality as a template argument. This means the range deduction guides cannot be used. This commit amends this by splitting them into separate function definitions, one per valid dimensionality. This adheres to [4.9.4.2.2. parallel_for invoke](https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#_parallel_for_invoke): > The parallel_for overload without an offset can be called with either a number or a braced-init-list with 1-3 elements. In that case the following calls are equivalent: > - parallel_for(N, some_kernel) has same effect as parallel_for(range<1>(N), some_kernel) > - parallel_for({N}, some_kernel) has same effect as parallel_for(range<1>(N), some_kernel) > - parallel_for({N1, N2}, some_kernel) has same effect as parallel_for(range<2>(N1, N2), some_kernel) > - parallel_for({N1, N2, N3}, some_kernel) has same effect as parallel_for(range<3>(N1, N2, N3), some_kernel) --------- Signed-off-by: Larsen, Steffen <steffen.larsen@intel.com>
- Loading branch information
1 parent
6d2a46a
commit 3024161
Showing
2 changed files
with
189 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
147 changes: 147 additions & 0 deletions
147
sycl/test/basic_tests/reduction/reduction_range_deduction.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,147 @@ | ||
// RUN: %clangxx -fsycl -fsyntax-only %s | ||
|
||
// Tests reduction parallel_for can use SYCL 2020 range deduction guides. | ||
|
||
#include <sycl/sycl.hpp> | ||
|
||
// User-defined binary functor that returns A + B for operands of type T. | ||
// NOTE(review): judging by the name, this stands in for a combiner that the | ||
// SYCL runtime has no known identity for, so that reductions built from it take | ||
// the identity-less code path - confirm against the reduction implementation. | ||
template <class T> struct PlusWithoutIdentity { | ||
T operator()(const T &A, const T &B) const { return A + B; } | ||
}; | ||
|
||
// Compile-only driver: the RUN line builds this file with -fsyntax-only, so the | ||
// point of every statement below is that it type-checks, not what it computes. | ||
// Each parallel_for call passes its range as a bare integer or a braced list | ||
// (never an explicit sycl::range) together with one, two or four reductions, | ||
// first through the sycl::queue shortcut and then through sycl::handler inside | ||
// Q.submit. All kernel lambdas are intentionally empty. | ||
int main() { | ||
sycl::queue Q; | ||
|
||
// Backing storage for the reductions: one int for the scalar reductions and | ||
// eight ints for the span reductions. | ||
// NOTE(review): these allocations are never released with sycl::free; that is | ||
// harmless only as long as the test stays syntax-only and is never executed. | ||
int *ScalarMem = sycl::malloc_shared<int>(1, Q); | ||
int *SpanMem = sycl::malloc_shared<int>(8, Q); | ||
// Four reduction objects: {scalar, span} x {std::plus, PlusWithoutIdentity}. | ||
auto ScalarRed1 = sycl::reduction(ScalarMem, std::plus<int>{}); | ||
auto ScalarRed2 = sycl::reduction(ScalarMem, PlusWithoutIdentity<int>{}); | ||
auto SpanRed1 = | ||
sycl::reduction(sycl::span<int, 8>{SpanMem, 8}, std::plus<int>{}); | ||
auto SpanRed2 = sycl::reduction(sycl::span<int, 8>{SpanMem, 8}, | ||
PlusWithoutIdentity<int>{}); | ||
|
||
// Shortcut and range<1> deduction from integer. | ||
// In every group the kernel takes a sycl::item of the expected dimensionality | ||
// followed by one `auto &` parameter per reduction passed to parallel_for. | ||
Q.parallel_for(1024, ScalarRed1, [=](sycl::item<1>, auto &) {}); | ||
Q.parallel_for(1024, SpanRed1, [=](sycl::item<1>, auto &) {}); | ||
Q.parallel_for(1024, ScalarRed1, ScalarRed2, | ||
[=](sycl::item<1>, auto &, auto &) {}); | ||
Q.parallel_for(1024, SpanRed1, SpanRed2, | ||
[=](sycl::item<1>, auto &, auto &) {}); | ||
Q.parallel_for(1024, ScalarRed1, SpanRed1, ScalarRed2, SpanRed2, | ||
[=](sycl::item<1>, auto &, auto &, auto &, auto &) {}); | ||
|
||
// Shortcut and range<1> deduction from initializer. | ||
Q.parallel_for({1024}, ScalarRed1, [=](sycl::item<1>, auto &) {}); | ||
Q.parallel_for({1024}, SpanRed1, [=](sycl::item<1>, auto &) {}); | ||
Q.parallel_for({1024}, ScalarRed1, ScalarRed2, | ||
[=](sycl::item<1>, auto &, auto &) {}); | ||
Q.parallel_for({1024}, SpanRed1, SpanRed2, | ||
[=](sycl::item<1>, auto &, auto &) {}); | ||
Q.parallel_for({1024}, ScalarRed1, SpanRed1, ScalarRed2, SpanRed2, | ||
[=](sycl::item<1>, auto &, auto &, auto &, auto &) {}); | ||
|
||
// Shortcut and range<2> deduction from initializer. | ||
Q.parallel_for({1024, 1024}, ScalarRed1, [=](sycl::item<2>, auto &) {}); | ||
Q.parallel_for({1024, 1024}, SpanRed1, [=](sycl::item<2>, auto &) {}); | ||
Q.parallel_for({1024, 1024}, ScalarRed1, ScalarRed2, | ||
[=](sycl::item<2>, auto &, auto &) {}); | ||
Q.parallel_for({1024, 1024}, SpanRed1, SpanRed2, | ||
[=](sycl::item<2>, auto &, auto &) {}); | ||
Q.parallel_for({1024, 1024}, ScalarRed1, SpanRed1, ScalarRed2, SpanRed2, | ||
[=](sycl::item<2>, auto &, auto &, auto &, auto &) {}); | ||
|
||
// Shortcut and range<3> deduction from initializer. | ||
Q.parallel_for({1024, 1024, 1024}, ScalarRed1, [=](sycl::item<3>, auto &) {}); | ||
Q.parallel_for({1024, 1024, 1024}, SpanRed1, [=](sycl::item<3>, auto &) {}); | ||
Q.parallel_for({1024, 1024, 1024}, ScalarRed1, ScalarRed2, | ||
[=](sycl::item<3>, auto &, auto &) {}); | ||
Q.parallel_for({1024, 1024, 1024}, SpanRed1, SpanRed2, | ||
[=](sycl::item<3>, auto &, auto &) {}); | ||
Q.parallel_for({1024, 1024, 1024}, ScalarRed1, SpanRed1, ScalarRed2, SpanRed2, | ||
[=](sycl::item<3>, auto &, auto &, auto &, auto &) {}); | ||
|
||
// Submission and range<1> deduction from integer. | ||
// From here on the same matrix is repeated through handler::parallel_for, | ||
// one command group per call. | ||
Q.submit([&](sycl::handler &CGH) { | ||
CGH.parallel_for(1024, ScalarRed1, [=](sycl::item<1>, auto &) {}); | ||
}); | ||
Q.submit([&](sycl::handler &CGH) { | ||
CGH.parallel_for(1024, SpanRed1, [=](sycl::item<1>, auto &) {}); | ||
}); | ||
Q.submit([&](sycl::handler &CGH) { | ||
CGH.parallel_for(1024, ScalarRed1, ScalarRed2, | ||
[=](sycl::item<1>, auto &, auto &) {}); | ||
}); | ||
Q.submit([&](sycl::handler &CGH) { | ||
CGH.parallel_for(1024, SpanRed1, SpanRed2, | ||
[=](sycl::item<1>, auto &, auto &) {}); | ||
}); | ||
Q.submit([&](sycl::handler &CGH) { | ||
CGH.parallel_for(1024, ScalarRed1, SpanRed1, ScalarRed2, SpanRed2, | ||
[=](sycl::item<1>, auto &, auto &, auto &, auto &) {}); | ||
}); | ||
|
||
// Submission and range<1> deduction from initializer. | ||
Q.submit([&](sycl::handler &CGH) { | ||
CGH.parallel_for({1024}, ScalarRed1, [=](sycl::item<1>, auto &) {}); | ||
}); | ||
Q.submit([&](sycl::handler &CGH) { | ||
CGH.parallel_for({1024}, SpanRed1, [=](sycl::item<1>, auto &) {}); | ||
}); | ||
Q.submit([&](sycl::handler &CGH) { | ||
CGH.parallel_for({1024}, ScalarRed1, ScalarRed2, | ||
[=](sycl::item<1>, auto &, auto &) {}); | ||
}); | ||
Q.submit([&](sycl::handler &CGH) { | ||
CGH.parallel_for({1024}, SpanRed1, SpanRed2, | ||
[=](sycl::item<1>, auto &, auto &) {}); | ||
}); | ||
Q.submit([&](sycl::handler &CGH) { | ||
CGH.parallel_for({1024}, ScalarRed1, SpanRed1, ScalarRed2, SpanRed2, | ||
[=](sycl::item<1>, auto &, auto &, auto &, auto &) {}); | ||
}); | ||
|
||
// Submission and range<2> deduction from initializer. | ||
Q.submit([&](sycl::handler &CGH) { | ||
CGH.parallel_for({1024, 1024}, ScalarRed1, [=](sycl::item<2>, auto &) {}); | ||
}); | ||
Q.submit([&](sycl::handler &CGH) { | ||
CGH.parallel_for({1024, 1024}, SpanRed1, [=](sycl::item<2>, auto &) {}); | ||
}); | ||
Q.submit([&](sycl::handler &CGH) { | ||
CGH.parallel_for({1024, 1024}, ScalarRed1, ScalarRed2, | ||
[=](sycl::item<2>, auto &, auto &) {}); | ||
}); | ||
Q.submit([&](sycl::handler &CGH) { | ||
CGH.parallel_for({1024, 1024}, SpanRed1, SpanRed2, | ||
[=](sycl::item<2>, auto &, auto &) {}); | ||
}); | ||
Q.submit([&](sycl::handler &CGH) { | ||
CGH.parallel_for({1024, 1024}, ScalarRed1, SpanRed1, ScalarRed2, SpanRed2, | ||
[=](sycl::item<2>, auto &, auto &, auto &, auto &) {}); | ||
}); | ||
|
||
// Submission and range<3> deduction from initializer. | ||
Q.submit([&](sycl::handler &CGH) { | ||
CGH.parallel_for({1024, 1024, 1024}, ScalarRed1, | ||
[=](sycl::item<3>, auto &) {}); | ||
}); | ||
Q.submit([&](sycl::handler &CGH) { | ||
CGH.parallel_for({1024, 1024, 1024}, SpanRed1, | ||
[=](sycl::item<3>, auto &) {}); | ||
}); | ||
Q.submit([&](sycl::handler &CGH) { | ||
CGH.parallel_for({1024, 1024, 1024}, ScalarRed1, ScalarRed2, | ||
[=](sycl::item<3>, auto &, auto &) {}); | ||
}); | ||
Q.submit([&](sycl::handler &CGH) { | ||
CGH.parallel_for({1024, 1024, 1024}, SpanRed1, SpanRed2, | ||
[=](sycl::item<3>, auto &, auto &) {}); | ||
}); | ||
Q.submit([&](sycl::handler &CGH) { | ||
CGH.parallel_for({1024, 1024, 1024}, ScalarRed1, SpanRed1, ScalarRed2, | ||
SpanRed2, | ||
[=](sycl::item<3>, auto &, auto &, auto &, auto &) {}); | ||
}); | ||
|
||
return 0; | ||
} |