82 changes: 82 additions & 0 deletions python/test/gluon/test_frontend.py
@@ -11,6 +11,7 @@
from triton.experimental.gluon.nvidia.hopper import TensorDescriptor
from triton.experimental.gluon.language.amd import _layouts as amd_layouts
from triton.experimental.gluon.language.amd.cdna4 import async_copy as cdna4_async_copy
from triton.experimental.gluon.language.amd.gfx1250 import async_copy as gfx1250_async_copy
from triton.experimental.gluon.language.extra import libdevice

from triton._filecheck import filecheck_test, run_parser
@@ -1949,6 +1950,87 @@ def test_infer_layout_for_amd_wmma(target):
""")


@gluon.jit
def amd_async_copy_global_to_shared(ptr):
blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [32, 1], [4, 1], [1, 0])
shared: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0])

smem = ttgl.allocate_shared_memory(ptr.dtype.element_ty, [128, 16], shared)
y_offset = ttgl.arange(0, 128, layout=ttgl.SliceLayout(1, blocked))
x_offset = ttgl.arange(0, 16, layout=ttgl.SliceLayout(0, blocked))
offsets = y_offset[:, None] * 16 + x_offset[None, :]

    # Test default parameters
gfx1250_async_copy.global_to_shared(smem, ptr + offsets)

    # Test mask
mask = (y_offset < 64)[:, None]
gfx1250_async_copy.global_to_shared(smem, ptr + offsets, mask)

# Test other with scalar
gfx1250_async_copy.global_to_shared(smem, ptr + offsets, mask, other=0.0)

# Test other with tensor
other = ttgl.full([128, 16], 0.0, ptr.dtype.element_ty, layout=blocked)
gfx1250_async_copy.global_to_shared(smem, ptr + offsets, mask, other)

gfx1250_async_copy.commit_group()


@pytest.mark.parametrize("target", [HIP_TARGET_GFX1250])
def test_amd_async_copy_global_to_shared(target):
ptr = MockTensor(ttgl.float16)
mod = run_parser(amd_async_copy_global_to_shared, *make_args(ptr), target=target)
expecttest.assert_expected_inline(
anonymize_ir(mod.str_nodebug()), """\
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [32, 1], warpsPerCTA = [4, 1], order = [1, 0]}>
#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}>
#smem = #ttg.shared_memory
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @amd_async_copy_global_to_shared(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%0 = ttg.local_alloc : () -> !ttg.memdesc<128x16xf16, #shared, #smem, mutable>
%1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
%2 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>>
%3 = tt.expand_dims %1 {axis = 1 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi32, #blocked>
%c16_i32 = arith.constant 16 : i32
%c16_i32_0 = arith.constant 16 : i32
%cst = arith.constant dense<16> : tensor<128x1xi32, #blocked>
%4 = arith.muli %3, %cst : tensor<128x1xi32, #blocked>
%5 = tt.expand_dims %2 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked>
%6 = tt.broadcast %4 : tensor<128x1xi32, #blocked> -> tensor<128x16xi32, #blocked>
%7 = tt.broadcast %5 : tensor<1x16xi32, #blocked> -> tensor<128x16xi32, #blocked>
%8 = arith.addi %6, %7 : tensor<128x16xi32, #blocked>
%9 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x16x!tt.ptr<f16>, #blocked>
%10 = tt.addptr %9, %8 : tensor<128x16x!tt.ptr<f16>, #blocked>, tensor<128x16xi32, #blocked>
%11 = ttg.async_copy_global_to_local %10, %0 : tensor<128x16x!tt.ptr<f16>, #blocked> -> <128x16xf16, #shared, #smem, mutable>
%c64_i32 = arith.constant 64 : i32
%cst_1 = arith.constant dense<64> : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
%12 = arith.cmpi slt, %1, %cst_1 : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked}>>
%13 = tt.expand_dims %12 {axis = 1 : i32} : tensor<128xi1, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<128x1xi1, #blocked>
%14 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x16x!tt.ptr<f16>, #blocked>
%15 = tt.addptr %14, %8 : tensor<128x16x!tt.ptr<f16>, #blocked>, tensor<128x16xi32, #blocked>
%16 = tt.broadcast %13 : tensor<128x1xi1, #blocked> -> tensor<128x16xi1, #blocked>
%17 = ttg.async_copy_global_to_local %15, %0 mask %16 : tensor<128x16x!tt.ptr<f16>, #blocked> -> <128x16xf16, #shared, #smem, mutable>
%18 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x16x!tt.ptr<f16>, #blocked>
%19 = tt.addptr %18, %8 : tensor<128x16x!tt.ptr<f16>, #blocked>, tensor<128x16xi32, #blocked>
%20 = tt.broadcast %13 : tensor<128x1xi1, #blocked> -> tensor<128x16xi1, #blocked>
%cst_2 = arith.constant 0.000000e+00 : f32
%21 = arith.truncf %cst_2 : f32 to f16
%22 = tt.splat %21 : f16 -> tensor<128x16xf16, #blocked>
%23 = ttg.async_copy_global_to_local %19, %0 mask %20 other %22 : tensor<128x16x!tt.ptr<f16>, #blocked> -> <128x16xf16, #shared, #smem, mutable>
%cst_3 = arith.constant 0.000000e+00 : f16
%cst_4 = arith.constant dense<0.000000e+00> : tensor<128x16xf16, #blocked>
%24 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<128x16x!tt.ptr<f16>, #blocked>
%25 = tt.addptr %24, %8 : tensor<128x16x!tt.ptr<f16>, #blocked>, tensor<128x16xi32, #blocked>
%26 = tt.broadcast %13 : tensor<128x1xi1, #blocked> -> tensor<128x16xi1, #blocked>
%27 = ttg.async_copy_global_to_local %25, %0 mask %26 other %cst_4 : tensor<128x16x!tt.ptr<f16>, #blocked> -> <128x16xf16, #shared, #smem, mutable>
%28 = ttg.async_commit_group
tt.return
}
}
""")


@gluon.jit
def amd_commit_group():
cdna4_async_copy.commit_group()
python/triton/experimental/gluon/language/amd/gfx1250/__init__.py
@@ -3,8 +3,9 @@
from .._layouts import AMDWMMALayout
from ..cdna3 import buffer_load, buffer_store
from . import tdm
from . import async_copy

__all__ = ["tdm", "wmma", "wmma_scaled", "buffer_load", "buffer_store", "get_wmma_scale_layout"]
__all__ = ["async_copy", "tdm", "wmma", "wmma_scaled", "buffer_load", "buffer_store", "get_wmma_scale_layout"]


def _get_wmma_scale_layout(dot_operand_layout, shape, semantic):
python/triton/experimental/gluon/language/amd/gfx1250/async_copy.py
@@ -0,0 +1,45 @@
from ..._core import ir, builtin, _unwrap_if_constexpr
from ..._semantic import _check
from triton.experimental.gluon.language._layouts import DistributedLayout
from ..cdna4.async_copy import commit_group, wait_group

__all__ = [
"global_to_shared",
"commit_group",
"wait_group",
]


@builtin
def global_to_shared(smem, pointer, mask=None, other=None, cache_modifier="", _semantic=None):
"""
    Asynchronously copy elements from global memory to shared memory. Requires manual synchronization via `wait_group` before accessing the loaded data.

Args:
smem (shared_memory_descriptor): Destination shared memory descriptor.
pointer (tensor): Source pointer tensor.
mask (tensor, optional): Mask tensor for predicated loads. Defaults to None.
        other (tensor or scalar, optional): Tensor or scalar providing default values for masked elements. Defaults to None, in which case masked elements are filled with 0.
cache_modifier (str): Cache modifier specifier. Defaults to "".
"""
    _check(pointer.type.is_block(), lambda: "expected pointer to be a tensor")
    _check(isinstance(pointer.type.layout, DistributedLayout),
           lambda: "expected pointer layout to be a DistributedLayout")
_check(
smem.shape == pointer.shape, lambda:
f"expected smem shape to match pointer shape but got smem.shape = {smem.shape}, pointer.shape = {pointer.shape}"
)
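    # Broadcast the optional mask up to the pointer's shape so predication is elementwise.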
mask = _unwrap_if_constexpr(mask)
if mask is not None:
pointer, mask = _semantic.broadcast_impl_value(pointer, mask)
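    # Cast `other` to the pointee element type, then broadcast it to the pointer's shape.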
other = _unwrap_if_constexpr(other)
if other is not None:
other = _semantic.to_tensor(other)
other = _semantic.cast(other, pointer.dtype.element_ty)
pointer, other = _semantic.broadcast_impl_value(pointer, other)
cache_modifier = _semantic._str_to_load_cache_modifier(cache_modifier)
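    # Omitted mask/other operands become null IR values; the eviction policy is fixed to NORMAL.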
mask_handle = mask.handle if mask is not None else ir.value()
other_handle = other.handle if other is not None else ir.value()
_semantic.builder.create_async_copy_global_to_local(smem.handle, pointer.handle, mask_handle, other_handle,
cache_modifier, ir.EVICTION_POLICY.NORMAL, False)
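
For orientation, here is a minimal usage sketch of the copy/commit/wait pipeline this module exposes. The layouts mirror the test above; `wait_group(0)` draining all outstanding groups and `smem.load(...)` reading the tile back are assumptions based on the shared cdna4 `async_copy` API and Gluon's shared memory descriptor, not part of this diff.

@gluon.jit
def async_load_tile_sketch(ptr):
    blocked: ttgl.constexpr = ttgl.BlockedLayout([1, 8], [32, 1], [4, 1], [1, 0])
    shared: ttgl.constexpr = ttgl.SwizzledSharedLayout(1, 1, 1, order=[1, 0])

    smem = ttgl.allocate_shared_memory(ptr.dtype.element_ty, [128, 16], shared)
    y_offset = ttgl.arange(0, 128, layout=ttgl.SliceLayout(1, blocked))
    x_offset = ttgl.arange(0, 16, layout=ttgl.SliceLayout(0, blocked))
    offsets = y_offset[:, None] * 16 + x_offset[None, :]

    # Kick off the asynchronous global -> shared copy and close the group.
    gfx1250_async_copy.global_to_shared(smem, ptr + offsets)
    gfx1250_async_copy.commit_group()

    # Wait until 0 copy groups remain outstanding, then read the tile back
    # (smem.load with a distributed layout is assumed here).
    gfx1250_async_copy.wait_group(0)
    tile = smem.load(blocked)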