-
Notifications
You must be signed in to change notification settings - Fork 2.6k
[Frontend] Add @tl.aggregate which autogenerates a Triton type
#6970
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
b252a6e
c8b1942
1c4d5a0
3ce1eb6
aa8edda
e58f4fa
5cb4908
1225046
ec5c926
63cb6a6
3dc74bd
493246d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,7 +2,6 @@ | |
| import os | ||
| import io | ||
| import inspect | ||
| from typing import List, Tuple | ||
|
|
||
| from filecheck.options import Options | ||
| from filecheck.finput import FInput | ||
|
|
@@ -14,7 +13,6 @@ | |
| from triton.compiler import ASTSource, make_backend | ||
| from triton.backends.compiler import GPUTarget | ||
| from triton._C.libtriton import ir | ||
| from triton.language.core import base_type, base_value | ||
|
|
||
| import pytest | ||
|
|
||
|
|
@@ -113,38 +111,14 @@ def test_fn(): | |
| # ===-----------------------------------------------------------------------===# | ||
|
|
||
|
|
||
| class pair_type(base_type): | ||
|
|
||
| def __init__(self, first_type, second_type): | ||
| self.first_type = first_type | ||
| self.second_type = second_type | ||
|
|
||
| def __eq__(self, other) -> bool: | ||
| return self.first_type == other.first_type and self.second_type == other.second_type | ||
|
|
||
| def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[ir.value, int]: | ||
| first, cursor = self.first_type._unflatten_ir(handles, cursor) | ||
| second, cursor = self.second_type._unflatten_ir(handles, cursor) | ||
| return pair_value(first, second), cursor | ||
|
|
||
| def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None: | ||
| self.first_type._flatten_ir_types(builder, out) | ||
| self.second_type._flatten_ir_types(builder, out) | ||
|
|
||
| def mangle(self) -> str: | ||
| return f"pair<{self.first_type.mangle()}, {self.second_type.mangle()}>" | ||
|
|
||
|
|
||
| class pair_value(base_value): | ||
| @tl.core._aggregate | ||
| class Pair: | ||
| first: tl.tensor | ||
| second: tl.tensor | ||
|
|
||
| def __init__(self, first, second): | ||
| self.first = first | ||
| self.second = second | ||
| self.type = pair_type(first.type, second.type) | ||
|
|
||
| def _flatten_ir(self, handles: List[ir.value]) -> None: | ||
| self.first._flatten_ir(handles) | ||
| self.second._flatten_ir(handles) | ||
|
|
||
| @triton.jit | ||
| def get_first(self): | ||
|
|
@@ -158,19 +132,14 @@ def unpack(self): | |
| return self.get_first(), self.get_second() | ||
|
|
||
|
|
||
| @tl.core.builtin | ||
| def pair_value_ctor(first, second, _builder=None): | ||
| return pair_value(first, second) | ||
|
|
||
|
|
||
| @filecheck_test | ||
| @triton.jit | ||
| def test_assign_attribute(): | ||
| # CHECK-LABEL: assign_attribute | ||
| # CHECK: %c11_i32 = arith.constant 11 : i32 | ||
| # CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32} | ||
| scalar = 11 | ||
| pair = pair_value_ctor(tl.arange(0, 4), scalar) | ||
| pair = Pair(tl.arange(0, 4), scalar) | ||
| # CHECK: %c42_i32 = arith.constant 42 : i32 | ||
| # CHECK-NEXT: call @"anchor{{.*}}"([[RANGE]], %c42_i32) | ||
| pair.second = 42 | ||
|
|
@@ -185,9 +154,34 @@ def test_jit_method(): | |
| # CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32} | ||
| scalar = 11 | ||
| # CHECK: [[V:%.*]]:2 = tt.call @"unpack{{.*}}"([[RANGE]], %c11_i32) | ||
| pair = pair_value_ctor(tl.arange(0, 4), scalar) | ||
| pair = Pair(tl.arange(0, 4), scalar) | ||
| a, b = pair.unpack() | ||
| # CHECK: call @anchor{{.*}}([[V]]#0) | ||
| anchor(a) | ||
| # CHECK: call @anchor{{.*}}([[V]]#1) | ||
| anchor(b) | ||
|
|
||
|
|
||
| @tl.core._aggregate | ||
| class TypeWithBuiltinInitializer: | ||
| value: tl.tensor | ||
|
|
||
| def __init__(self, _builder=None): | ||
| self.value = tl.arange(0, 4, _builder=_builder) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not so neat. I think before we can make this public we need the methods to automatically be treated as As is, this seems reasonable as an internal tool though. Maybe just rename to private (
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Happy to rename it for now. The problem with making these We would need a more involved fix that makes all
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There is also another trick somewhere that makes member methods automatically
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Renamed to
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Even better. Would it be possible to use class without any annotations?
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Depending on the semantics, maybe. The frontend would have to infer the annotations from the attributes set by the initializer |
||
|
|
||
| def modify(self, value, _builder=None): | ||
| self.value = value | ||
|
|
||
|
|
||
| @filecheck_test | ||
| @triton.jit | ||
| def test_aggregate_initializers(): | ||
| # CHECK-LABEL: test_aggregate_initializers | ||
| value = TypeWithBuiltinInitializer() | ||
| # CHECK: [[RANGE:%.*]] = tt.make_range {end = 4 : i32, start = 0 : i32} | ||
| # CHECK: call @"anchor{{.*}}"([[RANGE]]) | ||
| anchor(value) | ||
| # CHECK: [[RANGE:%.*]] = tt.make_range {end = 8 : i32, start = 4 : i32} | ||
| # CHECK: call @"anchor{{.*}}"([[RANGE]]) | ||
| value.modify(tl.arange(4, 8)) | ||
| anchor(value) | ||
Uh oh!
There was an error while loading. Please reload this page.