Merged
11 changes: 10 additions & 1 deletion python/src/gluon_ir.cc
@@ -11,6 +11,7 @@

using namespace mlir;
namespace py = pybind11;
namespace tt = triton;
namespace ttg = triton::gpu;
namespace ttng = triton::nvidia_gpu;

@@ -298,7 +299,15 @@ void init_gluon_ir(py::module &&m) {
self.create<ttng::AsyncTMAScatterOp>(descPtr, xOffsets, yOffset,
src);
})

.def("create_broadcast",
[](TritonOpBuilder &self, Value &arg, Type retTy) -> Value {
return self.create<tt::BroadcastOp>(retTy, arg);
})
.def(
"create_expand_dims",
[](TritonOpBuilder &self, Value &arg, int axis, Type retTy) -> Value {
return self.create<tt::ExpandDimsOp>(retTy, arg, axis);
})
.def("create_warp_return",
[](GluonOpBuilder &self) -> Operation * {
return self.create<ttg::WarpReturnOp>();
35 changes: 35 additions & 0 deletions python/test/gluon/test_frontend.py
@@ -600,3 +600,38 @@ def kernel():
}
}
""")


@gluon.jit
def broadcast_kernel():
layout: ttgl.constexpr = ttgl.BlockedLayout([1, 1], [2, 16], [4, 1], [1, 0])
a = ttgl.arange(0, 16, layout=ttgl.SliceLayout(0, layout))[None, :]
b = ttgl.arange(0, 16, layout=ttgl.SliceLayout(1, layout))[:, None]
0 + a + b


def test_broadcast(fresh_knobs):
knobs.compilation.disable_line_info = True

h = broadcast_kernel.warmup(sanitize_overflow=False, grid=(1, ))
expecttest.assert_expected_inline(
anonymize_ir(h.asm["source"]), """\
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [2, 16], warpsPerCTA = [4, 1], order = [1, 0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "...", "ttg.threads-per-warp" = 32 : i32} {
tt.func public @broadcast_kernel() attributes {noinline = false} {
%0 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> loc(#loc)
%1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 0, parent = #blocked}>> -> tensor<1x16xi32, #blocked> loc(#loc)
%2 = tt.make_range {end = 16 : i32, start = 0 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> loc(#loc)
%3 = tt.expand_dims %2 {axis = 1 : i32} : tensor<16xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<16x1xi32, #blocked> loc(#loc)
%c0_i32 = arith.constant 0 : i32 loc(#loc)
%c0_i32_0 = arith.constant 0 : i32 loc(#loc)
%cst = arith.constant dense<0> : tensor<1x16xi32, #blocked> loc(#loc)
%4 = arith.addi %cst, %1 : tensor<1x16xi32, #blocked> loc(#loc)
%5 = tt.broadcast %4 : tensor<1x16xi32, #blocked> -> tensor<16x16xi32, #blocked> loc(#loc)
%6 = tt.broadcast %3 : tensor<16x1xi32, #blocked> -> tensor<16x16xi32, #blocked> loc(#loc)
%7 = arith.addi %5, %6 : tensor<16x16xi32, #blocked> loc(#loc)
tt.return loc(#loc)
} loc(#loc)
} loc(#loc)
#loc = loc(unknown)
""")
1 change: 1 addition & 0 deletions python/triton/experimental/gluon/language/_core.py
@@ -42,6 +42,7 @@
)

_IMPORT_FROM_TRITON: List[str] = [
"expand_dims", # NOQA: F822
"program_id", # NOQA: F822
"load", # NOQA: F822
"store", # NOQA: F822
88 changes: 85 additions & 3 deletions python/triton/experimental/gluon/language/_semantic.py
@@ -1,11 +1,19 @@
from typing import Sequence
from typing import Sequence, List, TypeVar, Tuple, Callable
from triton.language.semantic import TritonSemantic
from . import _core as ttgl
from ._layouts import SliceLayout
from triton._C.libtriton.gluon_ir import GluonOpBuilder
from triton.compiler.code_generator import flatten_values_to_ir, unflatten_ir_values

TensorTy = TypeVar("TensorTy")

class GluonSemantic(TritonSemantic[ttgl.tensor]):

def _check(cond: bool, msg_fn: Callable[[], str], category=ValueError):
if not cond:
raise category(msg_fn())


class GluonSemantic(TritonSemantic[TensorTy]):
tensor = ttgl.tensor
lang = ttgl

@@ -14,6 +22,79 @@ class GluonSemantic(TritonSemantic[ttgl.tensor]):
def __init__(self, builder: GluonOpBuilder):
self.builder = builder

def _broadcast_shapes(self, lhs_shape: List[int], rhs_shape: List[int]):
if len(lhs_shape) != len(rhs_shape):
raise ValueError(f"Cannot broadcast, rank mismatch: {lhs_shape}, {rhs_shape}")

ret_shape = []
for i, left in enumerate(lhs_shape):
right = rhs_shape[i]
if left == 1:
ret_shape.append(right)
elif (right == 1) or (right == left):
ret_shape.append(left)
else:
raise ValueError("Cannot make_shape_compatible: incompatible dimensions "
"at index " + str(i) + ": " + str(left) + " and " + str(right))
return ret_shape

def expand_dims(self, input: TensorTy, axis: int) -> TensorTy:
dst_shape = [ttgl._unwrap_if_constexpr(x) for x in input.shape]
dst_shape.insert(axis, 1)

if axis < 0:
axis += len(input.shape)

_check(isinstance(input.type, ttgl.distributed_type),
lambda: f"expected expand_dims input to be a distributed_type but got: {input.type!r}")
layout = input.type.layout
_check(isinstance(layout, SliceLayout),
lambda: f"expected expand_dims input to have a SliceLayout, but got: {layout}")
_check(layout.dim == axis,
lambda: f"expected expand_dims input layout to be sliced in axis {axis} but got {layout.dim}")

ret_ty = ttgl.distributed_type(input.type.scalar, dst_shape, layout.parent)
handle = self.builder.create_expand_dims(input.handle, axis, ret_ty.to_ir(self.builder))
return self.tensor(handle, ret_ty)

def broadcast_impl_shape(self, input: TensorTy, shape: Tuple[int]) -> TensorTy:
_check(isinstance(input.type, ttgl.distributed_type),
lambda: f"expected expand_dims input to be a distributed_type but got: {input.type!r}")
src_shape = input.type.get_block_shapes()
_check(len(src_shape) == len(shape), lambda: f"Cannot broadcast, rank mismatch: {src_shape}, {shape}")
if shape == src_shape:
return input
for i, item in enumerate(src_shape):
if shape[i] != item and item != 1:
raise ValueError(f"Cannot broadcast, the expanded size of the tensor ({shape[i]})"
f" must match the existing size ({item}) at non-singleton dimension"
f" {i}: {src_shape}, {shape}")
ret_ty = ttgl.distributed_type(input.type.scalar, shape, input.type.layout)
handle = self.builder.create_broadcast(input.handle, ret_ty.to_ir(self.builder))
return self.tensor(handle, ret_ty)

def broadcast_impl_value(self, lhs: TensorTy, rhs: TensorTy) -> TensorTy:
lhs_ty = lhs.type
rhs_ty = rhs.type

if not lhs_ty.is_block() or not rhs_ty.is_block():
return super().broadcast_impl_value(lhs, rhs)

_check(isinstance(lhs_ty, ttgl.distributed_type),
lambda: f"expected broadcast left input to be a distributed_type but got: {lhs_ty!r}")
_check(isinstance(rhs_ty, ttgl.distributed_type),
lambda: f"expected broadcast right input to be a distributed_type but got: {rhs_ty!r}")

lhs_shape = lhs_ty.get_block_shapes()
rhs_shape = rhs_ty.get_block_shapes()
ret_shape = self._broadcast_shapes(lhs_shape, rhs_shape)
if lhs_ty.layout != rhs_ty.layout:
raise ValueError(f"Layout mismatch in broadcast: {lhs_ty.layout} vs {rhs_ty.layout}")

lhs = self.broadcast_impl_shape(lhs, ret_shape)
rhs = self.broadcast_impl_shape(rhs, ret_shape)
return lhs, rhs

def arange(self, start, end, layout):
shape = [end - start]
ret_ty = ttgl.distributed_type(ttgl.int32, shape, layout)
@@ -30,7 +111,8 @@ def full(self, shape, value, dtype, layout):

def convert_layout(self, value, layout):
ty = value.type
assert isinstance(ty, ttgl.distributed_type)
_check(isinstance(ty, ttgl.distributed_type),
lambda: f"expected convert_layout input to be a distributed_type but got: {ty!r}")
ret_ty = ttgl.distributed_type(ty.element_ty, ty.shape, layout)
handle = self.builder.create_convert_layout(ret_ty.to_ir(self.builder), value.handle)
return ttgl.tensor(handle, ret_ty)
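
For reference, a standalone plain-Python sketch of the broadcasting rule that GluonSemantic._broadcast_shapes above implements (an illustration under the same rule, not code from the PR): ranks must match exactly, a size-1 dimension stretches to the other operand's size, and any other mismatch raises an error.

def broadcast_shapes_sketch(lhs_shape, rhs_shape):
    if len(lhs_shape) != len(rhs_shape):
        raise ValueError(f"rank mismatch: {lhs_shape} vs {rhs_shape}")
    out = []
    for left, right in zip(lhs_shape, rhs_shape):
        if left == 1:
            out.append(right)              # size-1 stretches to the other size
        elif right == 1 or right == left:
            out.append(left)               # size-1 stretches, or sizes agree
        else:
            raise ValueError(f"incompatible dimensions: {left} and {right}")
    return out

assert broadcast_shapes_sketch([1, 16], [16, 1]) == [16, 16]
assert broadcast_shapes_sketch([16, 16], [16, 1]) == [16, 16]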
3 changes: 1 addition & 2 deletions python/triton/language/semantic.py
@@ -1055,8 +1055,7 @@ def _load_legacy(self, ptr, mask, other, boundary_check, padding, cache, evictio

# Create loaded result type `dst_ty`
if ptr.type.is_block():
shape = ptr.type.get_block_shapes()
dst_ty = tl.block_type(elt_ty, shape)
dst_ty = ptr.type.with_element_ty(elt_ty)
else:
# Load by de-referencing the pointer of scalar
dst_ty = elt_ty
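
The semantic.py change replaces a hand-built tl.block_type with ptr.type.with_element_ty(elt_ty). A rough sketch of the presumed intent (an assumption, not the actual Triton type implementation): rebuild the same kind of type with only the element type swapped, so a layout-carrying subtype such as Gluon's distributed_type keeps its layout instead of being collapsed to a plain block type.

class BlockTypeSketch:
    """Stand-in for a block type: an element type plus a shape."""

    def __init__(self, element_ty, shape):
        self.element_ty = element_ty
        self.shape = shape

    def with_element_ty(self, element_ty):
        # Same shape, new element type.
        return type(self)(element_ty, self.shape)


class DistributedTypeSketch(BlockTypeSketch):
    """Stand-in for a layout-carrying distributed type."""

    def __init__(self, element_ty, shape, layout):
        super().__init__(element_ty, shape)
        self.layout = layout

    def with_element_ty(self, element_ty):
        # The override preserves the layout across the element-type swap,
        # which constructing a fresh plain block type would have dropped.
        return type(self)(element_ty, self.shape, self.layout)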