Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 31 additions & 37 deletions python/test/unit/language/test_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import os
import io
import inspect
from typing import List, Tuple

from filecheck.options import Options
from filecheck.finput import FInput
Expand All @@ -14,7 +13,6 @@
from triton.compiler import ASTSource, make_backend
from triton.backends.compiler import GPUTarget
from triton._C.libtriton import ir
from triton.language.core import base_type, base_value

import pytest

Expand Down Expand Up @@ -113,38 +111,14 @@ def test_fn():
# ===-----------------------------------------------------------------------===#


class pair_type(base_type):

def __init__(self, first_type, second_type):
self.first_type = first_type
self.second_type = second_type

def __eq__(self, other) -> bool:
return self.first_type == other.first_type and self.second_type == other.second_type

def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[ir.value, int]:
first, cursor = self.first_type._unflatten_ir(handles, cursor)
second, cursor = self.second_type._unflatten_ir(handles, cursor)
return pair_value(first, second), cursor

def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
self.first_type._flatten_ir_types(builder, out)
self.second_type._flatten_ir_types(builder, out)

def mangle(self) -> str:
return f"pair<{self.first_type.mangle()}, {self.second_type.mangle()}>"


class pair_value(base_value):
@tl.core._aggregate
class Pair:
first: tl.tensor
second: tl.tensor

def __init__(self, first, second):
self.first = first
self.second = second
self.type = pair_type(first.type, second.type)

def _flatten_ir(self, handles: List[ir.value]) -> None:
self.first._flatten_ir(handles)
self.second._flatten_ir(handles)

@triton.jit
def get_first(self):
Expand All @@ -158,19 +132,14 @@ def unpack(self):
return self.get_first(), self.get_second()


@tl.core.builtin
def pair_value_ctor(first, second, _builder=None):
return pair_value(first, second)


@filecheck_test
@triton.jit
def test_assign_attribute():
# CHECK-LABEL: assign_attribute
# CHECK: %c11_i32 = arith.constant 11 : i32
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
scalar = 11
pair = pair_value_ctor(tl.arange(0, 4), scalar)
pair = Pair(tl.arange(0, 4), scalar)
# CHECK: %c42_i32 = arith.constant 42 : i32
# CHECK-NEXT: call @"anchor{{.*}}"([[RANGE]], %c42_i32)
pair.second = 42
Expand All @@ -185,9 +154,34 @@ def test_jit_method():
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
scalar = 11
# CHECK: [[V:%.*]]:2 = tt.call @"unpack{{.*}}"([[RANGE]], %c11_i32)
pair = pair_value_ctor(tl.arange(0, 4), scalar)
pair = Pair(tl.arange(0, 4), scalar)
a, b = pair.unpack()
# CHECK: call @anchor{{.*}}([[V]]#0)
anchor(a)
# CHECK: call @anchor{{.*}}([[V]]#1)
anchor(b)


@tl.core._aggregate
class TypeWithBuiltinInitializer:
value: tl.tensor

def __init__(self, _builder=None):
self.value = tl.arange(0, 4, _builder=_builder)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not so neat. I think before we can make this public we need the methods to automatically be treated as @jit functions and not have to manually pass the builder around.

As is, this seems reasonable as an internal tool though. Maybe just rename to private (_aggregate).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to rename it for now. The problem with making these @jit functions is that @jit arguments are not mutable through to the caller. I.e. doing self.first = 4 inside a JITFunction initializer won't actually change anything on the caller side.

We would need a more involved fix that makes all @jit methods automatically return self and patch the callsite to be x, other_results = x.method(*args)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is also another trick somewhere that makes member methods automatically @builtin, so we'd need to flip the convention for these types but that's not a big issue.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed to _aggregate. I'll ponder the self mutability problem

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even better. Would it be possible to use class without any annotations?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Depending on the semantics, maybe. The frontend would have to infer the annotations from the attributes set by the initializer


def modify(self, value, _builder=None):
self.value = value


@filecheck_test
@triton.jit
def test_aggregate_initializers():
# CHECK-LABEL: test_aggregate_initializers
value = TypeWithBuiltinInitializer()
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32}
# CHECK: call @"anchor{{.*}}"([[RANGE]])
anchor(value)
# CHECK: [[RANGE:%.*]] = tt.make_range {end = 8 : i32, start = 4 : i32}
# CHECK: call @"anchor{{.*}}"([[RANGE]])
value.modify(tl.arange(4, 8))
anchor(value)
1 change: 1 addition & 0 deletions python/triton/compiler/code_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ def global_lookup(name: str, absent):
type(val) is ModuleType, #
isinstance(val, JITFunction), #
getattr(val, "__triton_builtin__", False), #
getattr(val, "__triton_aggregate__", False), #
getattr(val, "__module__", "").startswith("triton.language"), #
isinstance(val, language.dtype), #
_is_namedtuple(val),
Expand Down
95 changes: 94 additions & 1 deletion python/triton/language/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,10 @@
from functools import partial, wraps
import typing
from typing import Union, Callable, List, Sequence, TypeVar, Optional, Tuple
from dataclasses import dataclass
import builtins
from .. import knobs
from ..runtime.jit import jit
from ..runtime.jit import jit, JITFunction
import inspect

from .._C.libtriton import ir
Expand Down Expand Up @@ -1487,6 +1488,98 @@ def _flatten_ir(self, handles: List[ir.value]) -> None:
handles.extend(s.handle for s in self.strides)


# -----------------------
# aggregate
# -----------------------


@dataclass(frozen=True)
class _aggregate_type(base_type):
"""A generic base type for all Triton aggregate types.

This class contains a reference to the original user-defined Python class
and a list of class fields with their Triton types.
"""

base_cls: type
fields: List[Tuple[str, base_type]]

def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[ir.value, int]:
instance = self.base_cls._get_instance()
for name, ty in self.fields:
value, cursor = ty._unflatten_ir(handles, cursor)
setattr(instance, name, value)
return instance, cursor

def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
for name, ty in self.fields:
ty._flatten_ir_types(builder, out)

def mangle(self) -> str:
name = f"{self.base_cls.__module__}.{self.base_cls.__qualname__}"
fields = [ty.mangle() for (name, ty) in self.fields]
return f"{name}<{', '.join(fields)}>"


def _aggregate(cls):

# Define the wrapped Triton value type.
class aggregate_value(base_value):
__triton_builtin__ = True
__triton_aggregate__ = True

@classmethod
def _get_instance(this_cls):
return super().__new__(this_cls)

def __new__(this_cls, *args, _builder=None, _generator=None, **kwargs):
# Call into the user-defined constructor.
instance = this_cls._get_instance()
if isinstance(cls.__init__, JITFunction):
raise ValueError(f"{cls.__name__}.__init__ cannot be a @triton.jit function")
extra_kwargs = {}
if "_builder" in inspect.signature(cls.__init__).parameters:
extra_kwargs["_builder"] = _builder
if "_generator" in inspect.signature(cls.__init__).parameters:
extra_kwargs["_generator"] = _generator
cls.__init__(instance, *args, **extra_kwargs, **kwargs)

# Require that the user-defined constructor initialized all fields.
for name in cls.__annotations__.keys():
if not hasattr(instance, name):
raise AttributeError(f"constructor for {cls.__name__} did not initialize attribute '{name}'")

return instance

# Only allow setting attributes defined in the class annotations.
def __setattr__(self, name, value):
if name not in cls.__annotations__:
raise AttributeError(f"{cls.__name__} has no attribute '{name}'")
if not isinstance(value, cls.__annotations__[name]):
raise TypeError(f"Expected {cls.__annotations__[name]} for attribute '{name}', got {type(value)}")
super().__setattr__(name, value)

def _flatten_ir(self, handles: List[ir.value]) -> None:
for name in cls.__annotations__.keys():
getattr(self, name)._flatten_ir(handles)

@property
def type(self):
return _aggregate_type(aggregate_value,
[(name, getattr(self, name).type) for name in cls.__annotations__.keys()])

for (name, member) in inspect.getmembers(cls):
if inspect.isfunction(member) or inspect.ismethod(member) or isinstance(member, JITFunction):
if name != "__init__":
setattr(aggregate_value, name, member)

aggregate_value.__name__ = cls.__name__
aggregate_value.__module__ = cls.__module__
aggregate_value.__qualname__ = cls.__qualname__

return aggregate_value


# -----------------------
# SPMD Programming Model
# -----------------------
Expand Down
Loading