Skip to content

Commit 1e7a8ea

Browse files
authored
[ESIMDS] Support tfloat32 types in dpas() (#6948)
Signed-off-by: Vyacheslav N Klochkov <[email protected]>
1 parent 9cf7451 commit 1e7a8ea

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

sycl/include/sycl/ext/intel/esimd/xmx/dpas.hpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,10 @@ namespace ext::intel::esimd::xmx {
2424
namespace detail {
2525

2626
template <typename T> constexpr dpas_argument_type dpas_precision_from_type() {
27-
// TODO: add support for tfloat32 here.
28-
if constexpr (std::is_same_v<T, sycl::half>)
27+
if constexpr (std::is_same_v<T,
28+
sycl::ext::intel::experimental::esimd::tfloat32>)
29+
return dpas_argument_type::tf32;
30+
else if constexpr (std::is_same_v<T, sycl::half>)
2931
return dpas_argument_type::fp16;
3032
else if constexpr (std::is_same_v<T,
3133
sycl::ext::oneapi::experimental::bfloat16>)

sycl/test/esimd/dpas.cpp

+21
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ SYCL_ESIMD_FUNCTION SYCL_EXTERNAL void xmx_func() {
207207
constexpr int K_half = 8 * 2;
208208
constexpr int K_bf16 = 8 * 2;
209209
constexpr int K_int8x2 = 8 * 4;
210+
constexpr int K_tf32 = 8 * 1;
210211
constexpr int N_pvc = 16;
211212
constexpr int N_dg2 = 8;
212213

@@ -338,6 +339,26 @@ SYCL_ESIMD_FUNCTION SYCL_EXTERNAL void xmx_func() {
338339
// CHECK: call <8 x float> @llvm.genx.dpasw.nosrc0.v8f32.v64i32.v4i32(<64 x i32> {{[^,]+}}, <4 x i32> {{[^,]+}}, i32 17304074)
339340
}
340341

342+
{ // ======= DPAS TFLOAT32 ===================================================
343+
simd<float, M_one *N_pvc> R_f = 0;
344+
simd<float, M_one *N_pvc> C_f = 0;
345+
346+
simd<sycl::ext::intel::experimental::esimd::tfloat32, K_tf32 *N_pvc> B_tf =
347+
0;
348+
simd<sycl::ext::intel::experimental::esimd::tfloat32, M_one *K_tf32> A_tf =
349+
0;
350+
351+
// ------------------- TFLOAT32: WITH ACC OPERAND --------------------------
352+
R_f = xmx::dpas<8, 1, float>(C_f, B_tf, A_tf);
353+
zoo(R_f);
354+
// CHECK: call <16 x float> @llvm.genx.dpas2.v16f32.v16f32.v128i32.v8i32(<16 x float> {{[^,]+}}, <128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 12, i32 12, i32 8, i32 1, i32 1, i32 1)
355+
356+
// ------------------- TFLOAT32: NO ACC OPERAND ----------------------------
357+
R_f = xmx::dpas<8, 1, float>(B_tf, A_tf);
358+
zoo(R_f);
359+
// CHECK: call <16 x float> @llvm.genx.dpas.nosrc0.v16f32.v128i32.v8i32(<128 x i32> {{[^,]+}}, <8 x i32> {{[^,]+}}, i32 17304588)
360+
}
361+
341362
xmx_func_end();
342363
// CHECK: call spir_func void @_Z12xmx_func_endv()
343364
}

0 commit comments

Comments
 (0)