diff --git a/python/test/gluon/test_frontend.py b/python/test/gluon/test_frontend.py
index 9e9ab4690252..10a4204b5df0 100644
--- a/python/test/gluon/test_frontend.py
+++ b/python/test/gluon/test_frontend.py
@@ -11,6 +11,7 @@
 from triton.experimental.gluon.nvidia.hopper import TensorDescriptor
 from triton.experimental.gluon.language.amd import _layouts as amd_layouts
 from triton.experimental.gluon.language.amd.cdna4 import async_copy as cdna4_async_copy
+from triton.experimental.gluon.language.amd.gfx1250 import async_copy as gfx1250_async_copy
 from triton.experimental.gluon.language.extra import libdevice
 from triton._filecheck import filecheck_test, run_parser
 
@@ -1949,6 +1950,87 @@ def test_infer_layout_for_amd_wmma(target):
 """)
 
 
+@gluon.jit
+def amd_async_copy_global_to_shared(ptr):
+    blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [32, 1], [4, 1], [1, 0])
+    shared: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0])
+
+    smem = ttgl.allocate_shared_memory(ptr.dtype.element_ty, [128, 16], shared)
+    y_offset = ttgl.arange(0, 128, layout=ttgl.SliceLayout(1, blocked))
+    x_offset = ttgl.arange(0, 16, layout=ttgl.SliceLayout(0, blocked))
+    offsets = y_offset[:, None] * 16 + x_offset[None, :]
+
+    # Test default parameters
+    gfx1250_async_copy.global_to_shared(smem, ptr + offsets)
+
+    # Test mask
+    mask = (y_offset < 64)[:, None]
+    gfx1250_async_copy.global_to_shared(smem, ptr + offsets, mask)
+
+    # Test other with scalar
+    gfx1250_async_copy.global_to_shared(smem, ptr + offsets, mask, other=0.0)
+
+    # Test other with tensor
+    other = ttgl.full([128, 16], 0.0, ptr.dtype.element_ty, layout=blocked)
+    gfx1250_async_copy.global_to_shared(smem, ptr + offsets, mask, other)
+
+    gfx1250_async_copy.commit_group()
+
+
+@pytest.mark.parametrize("target", [HIP_TARGET_GFX1250])
+def test_amd_async_copy_global_to_shared(target):
+    ptr = MockTensor(ttgl.float16)
+    mod = run_parser(amd_async_copy_global_to_shared, *make_args(ptr), target=target)
+    expecttest.assert_expected_inline(
+        anonymize_ir(mod.str_nodebug()), """\
+#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
+#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
+#smem = #ttg.shared_memory
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
+  tt.func public @amd_async_copy_global_to_shared(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
+    %0 = ttg.local_alloc : () -> !ttg.memdesc<128x16xf16, #shared, #smem, mutable>
+    %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %2 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
+    %3 = tt.expand_dims %1 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked>
+    %c16_i32 = arith.constant 16 : i32
+    %c16_i32_0 = arith.constant 16 : i32
+    %cst = arith.constant dense<16> : tensor<128x1xi32, #blocked>
+    %4 = arith.muli %3, %cst : tensor<128x1xi32, #blocked>
+    %5 = tt.expand_dims %2 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked>
+    %6 = tt.broadcast %4 : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked>
+    %7 = tt.broadcast %5 : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked>
+    %8 = arith.addi %6, %7 : tensor<128x16xi32, #blocked>
+    %9 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x16x!tt.ptr<f16>, #blocked>
+    %10 = tt.addptr %9, %8 : tensor<128x16x!tt.ptr<f16>, #blocked>, tensor<128x16xi32, #blocked>
+    %11 = ttg.async_copy_global_to_local %10, %0 : tensor<128x16x!tt.ptr<f16>, #blocked> -> <128x16xf16, #shared, #smem, mutable>
+    %c64_i32 = arith.constant 64 : i32
+    %cst_1 = arith.constant dense<64> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %12 = arith.cmpi slt, %1, %cst_1 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
+    %13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi1, #blocked>
+    %14 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x16x!tt.ptr<f16>, #blocked>
+    %15 = tt.addptr %14, %8 : tensor<128x16x!tt.ptr<f16>, #blocked>, tensor<128x16xi32, #blocked>
+    %16 = tt.broadcast %13 : tensor<128x1xi1, #blocked> -> tensor<128x16xi1, #blocked>
+    %17 = ttg.async_copy_global_to_local %15, %0 mask %16 : tensor<128x16x!tt.ptr<f16>, #blocked> -> <128x16xf16, #shared, #smem, mutable>
+    %18 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x16x!tt.ptr<f16>, #blocked>
+    %19 = tt.addptr %18, %8 : tensor<128x16x!tt.ptr<f16>, #blocked>, tensor<128x16xi32, #blocked>
+    %20 = tt.broadcast %13 : tensor<128x1xi1, #blocked> -> tensor<128x16xi1, #blocked>
+    %cst_2 = arith.constant 0.000000e+00 : f32
+    %21 = arith.truncf %cst_2 : f32 to f16
+    %22 = tt.splat %21 : f16 -> tensor<128x16xf16, #blocked>
+    %23 = ttg.async_copy_global_to_local %19, %0 mask %20 other %22 : tensor<128x16x!tt.ptr<f16>, #blocked> -> <128x16xf16, #shared, #smem, mutable>
+    %cst_3 = arith.constant 0.000000e+00 : f16
+    %cst_4 = arith.constant dense<0.000000e+00> : tensor<128x16xf16, #blocked>
+    %24 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x16x!tt.ptr<f16>, #blocked>
+    %25 = tt.addptr %24, %8 : tensor<128x16x!tt.ptr<f16>, #blocked>, tensor<128x16xi32, #blocked>
+    %26 = tt.broadcast %13 : tensor<128x1xi1, #blocked> -> tensor<128x16xi1, #blocked>
+    %27 = ttg.async_copy_global_to_local %25, %0 mask %26 other %cst_4 : tensor<128x16x!tt.ptr<f16>, #blocked> -> <128x16xf16, #shared, #smem, mutable>
+    %28 = ttg.async_commit_group
+    tt.return
+  }
+}
+""")
+
+
 @gluon.jit
 def amd_commit_group():
     cdna4_async_copy.commit_group()
diff --git a/python/triton/experimental/gluon/language/amd/gfx1250/__init__.py b/python/triton/experimental/gluon/language/amd/gfx1250/__init__.py
index 2492eefb783e..c30111f1e526 100644
--- a/python/triton/experimental/gluon/language/amd/gfx1250/__init__.py
+++ b/python/triton/experimental/gluon/language/amd/gfx1250/__init__.py
@@ -3,8 +3,9 @@
 from .._layouts import AMDWMMALayout
 from ..cdna3 import buffer_load, buffer_store
 from . import tdm
+from . import async_copy
 
-__all__ = ["tdm", "wmma", "wmma_scaled", "buffer_load", "buffer_store", "get_wmma_scale_layout"]
+__all__ = ["async_copy", "tdm", "wmma", "wmma_scaled", "buffer_load", "buffer_store", "get_wmma_scale_layout"]
 
 
 def _get_wmma_scale_layout(dot_operand_layout, shape, semantic):
diff --git a/python/triton/experimental/gluon/language/amd/gfx1250/async_copy.py b/python/triton/experimental/gluon/language/amd/gfx1250/async_copy.py
new file mode 100644
index 000000000000..cae916c69f02
--- /dev/null
+++ b/python/triton/experimental/gluon/language/amd/gfx1250/async_copy.py
@@ -0,0 +1,44 @@
+from ..._core import ir, builtin, _unwrap_if_constexpr
+from ..._semantic import _check
+from triton.experimental.gluon.language._layouts import DistributedLayout
+from ..cdna4.async_copy import commit_group, wait_group
+
+__all__ = [
+    "global_to_shared",
+    "commit_group",
+    "wait_group",
+]
+
+
+@builtin
+def global_to_shared(smem, pointer, mask=None, other=None, cache_modifier="", _semantic=None):
+    """
+    Asynchronously copy elements from global memory to shared memory. Requires manual synchronization via `wait_group` before accessing the loaded data.
+
+    Args:
+        smem (shared_memory_descriptor): Destination shared memory descriptor.
+        pointer (tensor): Source pointer tensor.
+        mask (tensor, optional): Mask tensor for predicated loads. Defaults to None.
+        other (tensor or scalar, optional): Tensor or scalar providing default values for masked elements. Defaults to None, which behaves as 0.
+        cache_modifier (str): Cache modifier specifier. Defaults to "".
+    """
+    _check(pointer.type.is_block(), lambda: "expected ptr to be a tensor")
+    _check(isinstance(pointer.type.layout, DistributedLayout),
+           lambda: "expected ptr layout to be a DistributedLayout")
+    _check(
+        smem.shape == pointer.shape, lambda:
+        f"expected smem shape to match pointer shape but got smem.shape = {smem.shape}, pointer.shape = {pointer.shape}"
+    )
+    mask = _unwrap_if_constexpr(mask)
+    if mask is not None:
+        pointer, mask = _semantic.broadcast_impl_value(pointer, mask)
+    other = _unwrap_if_constexpr(other)
+    if other is not None:
+        other = _semantic.to_tensor(other)
+        other = _semantic.cast(other, pointer.dtype.element_ty)
+        pointer, other = _semantic.broadcast_impl_value(pointer, other)
+    cache_modifier = _semantic._str_to_load_cache_modifier(cache_modifier)
+    mask_handle = mask.handle if mask is not None else ir.value()
+    other_handle = other.handle if other is not None else ir.value()
+    _semantic.builder.create_async_copy_global_to_local(smem.handle, pointer.handle, mask_handle, other_handle,
+                                                        cache_modifier, ir.EVICTION_POLICY.NORMAL, False)
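
For reference, a minimal end-to-end usage sketch of the new API, based on the test above. The kernel name, tile shape, and the `wait_group(0)` argument are illustrative assumptions rather than part of this change; `smem.load` and `ttgl.store` come from the existing Gluon language surface.

```python
from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl
from triton.experimental.gluon.language.amd.gfx1250 import async_copy


@gluon.jit
def copy_tile_kernel(in_ptr, out_ptr):  # hypothetical kernel, mirrors the 128x16 tile in the test
    blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [32, 1], [4, 1], [1, 0])
    shared: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0])

    smem = ttgl.allocate_shared_memory(in_ptr.dtype.element_ty, [128, 16], shared)
    y = ttgl.arange(0, 128, layout=ttgl.SliceLayout(1, blocked))
    x = ttgl.arange(0, 16, layout=ttgl.SliceLayout(0, blocked))
    offsets = y[:, None] * 16 + x[None, :]

    # Issue the asynchronous global->shared copy and close the copy group.
    async_copy.global_to_shared(smem, in_ptr + offsets)
    async_copy.commit_group()

    # Block until no copy groups are outstanding before touching the data
    # (assumed semantics of wait_group, re-exported from cdna4).
    async_copy.wait_group(0)

    # Read the tile back out of shared memory and write it to global memory.
    tile = smem.load(blocked)
    ttgl.store(out_ptr + offsets, tile)
```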