Skip to content

Commit 3fd0850

Browse files
authored
[SYCL][ESIMD]Add invoke_simd support for functions returning void. (#6901)
1 parent 134618f commit 3fd0850

File tree

2 files changed

+40
-0
lines changed

2 files changed

+40
-0
lines changed

sycl/include/sycl/ext/oneapi/experimental/invoke_simd.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,8 @@ struct simd2spmd<T, std::enable_if_t<std::is_arithmetic_v<T>>> {
139139
using type = uniform<T>;
140140
};
141141

142+
template <> struct simd2spmd<void> { using type = void; };
143+
142144
// Determine number of elements in a simd type.
143145
template <class T> struct simd_size {
144146
static constexpr int value = 1; // 1 element in any type by default

sycl/test/invoke_simd/invoke_simd.cpp

+38
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ ESIMD_CALLEE(float *A, esimd::simd<float, VL> b, int i) SYCL_ESIMD_FUNCTION {
3838
[[intel::device_indirectly_callable]] SYCL_EXTERNAL
3939
simd<float, VL> __regcall SIMD_CALLEE(float *A, simd<float, VL> b,
4040
int i) SYCL_ESIMD_FUNCTION;
41+
[[intel::device_indirectly_callable]] SYCL_EXTERNAL
42+
void __regcall SIMD_CALLEE_VOID(simd<float, VL> b, int i) SYCL_ESIMD_FUNCTION {}
4143

4244
float SPMD_CALLEE(float *A, float b, int i) { return A[i] + b; }
4345

@@ -120,6 +122,7 @@ int main(void) {
120122
if constexpr (use_invoke_simd) {
121123
res = invoke_simd(sg, SIMD_CALLEE, uniform{A}, B[wi_id],
122124
uniform{i});
125+
invoke_simd(sg, SIMD_CALLEE_VOID, B[wi_id], uniform{i});
123126
} else {
124127
res = SPMD_CALLEE(A, B[wi_id], wi_id);
125128
}
@@ -183,6 +186,8 @@ struct SIMD_FUNCTOR {
183186
// E - N(u, N)
184187
SYCL_EXTERNAL __regcall simd<short, 8> operator()(simd<float, 3>,
185188
simd<int, 8>) const;
189+
// F - void
190+
SYCL_EXTERNAL __regcall void operator()(simd<float, 3>) const;
186191
};
187192

188193
// Functor-based tests.
@@ -208,6 +213,10 @@ SYCL_EXTERNAL void foo(sub_group sg, float a, float b, float *ptr) {
208213
// the target is "E" SIMD_FUNCTOR::() overload:
209214
auto v = invoke_simd(sg, ftor, uniform{simd<float, 3>{1}}, 1);
210215
static_assert(std::is_same_v<decltype(v), short>);
216+
217+
// the target is "F" SIMD_FUNCTOR::() overload:
218+
invoke_simd(sg, ftor, uniform{simd<float, 3>{1}});
219+
211220
}
212221

213222
// Lambda-based tests, repeat functor test cases above.
@@ -253,6 +262,29 @@ SYCL_EXTERNAL auto bar(sub_group sg, float a, float b, float *ptr, char ch) {
253262
auto v = invoke_simd(sg, ftor, uniform{simd<float, 3>{1}}, 1);
254263
static_assert(std::is_same_v<decltype(v), short>);
255264
}
265+
266+
{
267+
const auto ftor = [=] [[gnu::regcall]] (simd<float, 16>, float) {};
268+
invoke_simd(sg, ftor, 1.f, uniform{a});
269+
}
270+
{
271+
const auto ftor = [=] [[gnu::regcall]] (simd<float, 8>, float, int) {};
272+
invoke_simd(sg, ftor, b, uniform{1.f}, uniform{10});
273+
}
274+
{
275+
const auto ftor = [=] [[gnu::regcall]] (simd<float, 16>, float *) {};
276+
invoke_simd(sg, ftor, b, uniform{ptr});
277+
}
278+
{
279+
const auto ftor = [=] [[gnu::regcall]] (float *, simd<float, 3>,
280+
simd<int, 5>) {};
281+
invoke_simd(sg, ftor, uniform{ptr}, uniform{simd<float, 3>{1}},
282+
uniform{simd<int, 5>{2}});
283+
}
284+
{
285+
const auto ftor = [=] [[gnu::regcall]] (simd<float, 3>, simd<int, 8>) {};
286+
invoke_simd(sg, ftor, uniform{simd<float, 3>{1}}, 1);
287+
}
256288
}
257289

258290
// Function-pointer-based test
@@ -262,6 +294,11 @@ SYCL_EXTERNAL auto barx(sub_group sg, float a, char ch,
262294
static_assert(std::is_same_v<decltype(x), uniform<char>>);
263295
}
264296

297+
SYCL_EXTERNAL auto barx_void(sub_group sg, float a, char ch,
298+
__regcall void(f)(simd<float, 16>, float)) {
299+
invoke_simd(sg, f, 1.f, uniform{a});
300+
}
301+
265302
// Internal is_function_ref_v meta-API checks {
266303
template <class F> void assert_is_func(F &&f) {
267304
static_assert(
@@ -282,6 +319,7 @@ void check_f(
282319
int(func)(float), int(__regcall func_regcall)(int)) {
283320

284321
assert_is_func(SIMD_CALLEE);
322+
assert_is_func(SIMD_CALLEE_VOID);
285323
assert_is_func(ordinary_func);
286324

287325
assert_is_func(func_ptr);

0 commit comments

Comments
 (0)