Skip to content

Commit

Permalink
[SYCL][ESIMD]Add invoke_simd support for functions returning void. (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
fineg74 authored Oct 6, 2022
1 parent 134618f commit 3fd0850
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
2 changes: 2 additions & 0 deletions sycl/include/sycl/ext/oneapi/experimental/invoke_simd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,8 @@ struct simd2spmd<T, std::enable_if_t<std::is_arithmetic_v<T>>> {
using type = uniform<T>;
};

template <> struct simd2spmd<void> { using type = void; };

// Determine number of elements in a simd type.
template <class T> struct simd_size {
static constexpr int value = 1; // 1 element in any type by default
Expand Down
38 changes: 38 additions & 0 deletions sycl/test/invoke_simd/invoke_simd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ ESIMD_CALLEE(float *A, esimd::simd<float, VL> b, int i) SYCL_ESIMD_FUNCTION {
[[intel::device_indirectly_callable]] SYCL_EXTERNAL
simd<float, VL> __regcall SIMD_CALLEE(float *A, simd<float, VL> b,
int i) SYCL_ESIMD_FUNCTION;
[[intel::device_indirectly_callable]] SYCL_EXTERNAL
void __regcall SIMD_CALLEE_VOID(simd<float, VL> b, int i) SYCL_ESIMD_FUNCTION {}

float SPMD_CALLEE(float *A, float b, int i) { return A[i] + b; }

Expand Down Expand Up @@ -120,6 +122,7 @@ int main(void) {
if constexpr (use_invoke_simd) {
res = invoke_simd(sg, SIMD_CALLEE, uniform{A}, B[wi_id],
uniform{i});
invoke_simd(sg, SIMD_CALLEE_VOID, B[wi_id], uniform{i});
} else {
res = SPMD_CALLEE(A, B[wi_id], wi_id);
}
Expand Down Expand Up @@ -183,6 +186,8 @@ struct SIMD_FUNCTOR {
// E - N(u, N)
SYCL_EXTERNAL __regcall simd<short, 8> operator()(simd<float, 3>,
simd<int, 8>) const;
// F - void
SYCL_EXTERNAL __regcall void operator()(simd<float, 3>) const;
};

// Functor-based tests.
Expand All @@ -208,6 +213,10 @@ SYCL_EXTERNAL void foo(sub_group sg, float a, float b, float *ptr) {
// the target is "E" SIMD_FUNCTOR::() overload:
auto v = invoke_simd(sg, ftor, uniform{simd<float, 3>{1}}, 1);
static_assert(std::is_same_v<decltype(v), short>);

// the target is "F" SIMD_FUNCTOR::() overload:
invoke_simd(sg, ftor, uniform{simd<float, 3>{1}});

}

// Lambda-based tests, repeat functor test cases above.
Expand Down Expand Up @@ -253,6 +262,29 @@ SYCL_EXTERNAL auto bar(sub_group sg, float a, float b, float *ptr, char ch) {
auto v = invoke_simd(sg, ftor, uniform{simd<float, 3>{1}}, 1);
static_assert(std::is_same_v<decltype(v), short>);
}

{
const auto ftor = [=] [[gnu::regcall]] (simd<float, 16>, float) {};
invoke_simd(sg, ftor, 1.f, uniform{a});
}
{
const auto ftor = [=] [[gnu::regcall]] (simd<float, 8>, float, int) {};
invoke_simd(sg, ftor, b, uniform{1.f}, uniform{10});
}
{
const auto ftor = [=] [[gnu::regcall]] (simd<float, 16>, float *) {};
invoke_simd(sg, ftor, b, uniform{ptr});
}
{
const auto ftor = [=] [[gnu::regcall]] (float *, simd<float, 3>,
simd<int, 5>) {};
invoke_simd(sg, ftor, uniform{ptr}, uniform{simd<float, 3>{1}},
uniform{simd<int, 5>{2}});
}
{
const auto ftor = [=] [[gnu::regcall]] (simd<float, 3>, simd<int, 8>) {};
invoke_simd(sg, ftor, uniform{simd<float, 3>{1}}, 1);
}
}

// Function-pointer-based test
Expand All @@ -262,6 +294,11 @@ SYCL_EXTERNAL auto barx(sub_group sg, float a, char ch,
static_assert(std::is_same_v<decltype(x), uniform<char>>);
}

SYCL_EXTERNAL auto barx_void(sub_group sg, float a, char ch,
__regcall void(f)(simd<float, 16>, float)) {
invoke_simd(sg, f, 1.f, uniform{a});
}

// Internal is_function_ref_v meta-API checks {
template <class F> void assert_is_func(F &&f) {
static_assert(
Expand All @@ -282,6 +319,7 @@ void check_f(
int(func)(float), int(__regcall func_regcall)(int)) {

assert_is_func(SIMD_CALLEE);
assert_is_func(SIMD_CALLEE_VOID);
assert_is_func(ordinary_func);

assert_is_func(func_ptr);
Expand Down

0 comments on commit 3fd0850

Please sign in to comment.