-
Notifications
You must be signed in to change notification settings - Fork 2.6k
[Frontend] Add @tl.aggregate which autogenerates a Triton type
#6970
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
| value: tl.tensor | ||
|
|
||
| def __init__(self, _builder=None): | ||
| self.value = tl.arange(0, 4, _builder=_builder) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not so neat. I think before we can make this public we need the methods to automatically be treated as @jit functions and not have to manually pass the builder around.
As is, this seems reasonable as an internal tool though. Maybe just rename to private (_aggregate).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Happy to rename it for now. The problem with making these @jit functions is that @jit arguments are not mutable through to the caller. I.e. doing self.first = 4 inside a JITFunction initializer won't actually change anything on the caller side.
We would need a more involved fix that makes all @jit methods automatically return self and patch the callsite to be x, other_results = x.method(*args)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is also another trick somewhere that makes member methods automatically @builtin, so we'd need to flip the convention for these types but that's not a big issue.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Renamed to _aggregate. I'll ponder the self mutability problem
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Even better. Would it be possible to use class without any annotations?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Depending on the semantics, maybe. The frontend would have to infer the annotations from the attributes set by the initializer
📚 Stacked PRs 📚
@tl.aggregatewhich autogenerates a Triton type #6970selftoJITFunctionwhen they are methods #6963This PR adds a
@tl.aggregatedecorator which, when placed on a Python class with field annotations, automatically generates a Tritonbase_typeandbase_valuebased on the class. It wraps the type in a Triton type and moves all the methods over.This makes creating custom Triton types less verbose.