Skip to content

Conversation

@Mogball
Copy link
Collaborator

@Mogball Mogball commented May 28, 2025

📚 Stacked PRs 📚

This PR adds a @tl.aggregate decorator which, when placed on a Python class with field annotations, automatically generates a Triton base_type and base_value based on the class. It wraps the type in a Triton type and moves all the methods over.

This makes creating custom Triton types less verbose.

@Mogball Mogball marked this pull request as ready for review May 28, 2025 23:51
@Mogball Mogball requested a review from ptillet as a code owner May 28, 2025 23:51
value: tl.tensor

def __init__(self, _builder=None):
self.value = tl.arange(0, 4, _builder=_builder)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not so neat. I think before we can make this public we need the methods to automatically be treated as @jit functions and not have to manually pass the builder around.

As is, this seems reasonable as an internal tool though. Maybe just rename to private (_aggregate).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Happy to rename it for now. The problem with making these @jit functions is that @jit arguments are not mutable through to the caller. I.e. doing self.first = 4 inside a JITFunction initializer won't actually change anything on the caller side.

We would need a more involved fix that makes all @jit methods automatically return self and patch the callsite to be x, other_results = x.method(*args)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is also another trick somewhere that makes member methods automatically @builtin, so we'd need to flip the convention for these types but that's not a big issue.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed to _aggregate. I'll ponder the self mutability problem

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even better. Would it be possible to use class without any annotations?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Depending on the semantics, maybe. The frontend would have to infer the annotations from the attributes set by the initializer

Mogball added a commit that referenced this pull request May 29, 2025
…ods (#6963)

📚 Stacked PRs 📚

* #6970
* ➡️ #6963

This PR makes `base_value.method` return a BoundJITMethod which keeps
`base_value` to be passed as `__self__`.
Base automatically changed from mogball/jit_method to main May 29, 2025 02:27
@Mogball Mogball merged commit 4c7a5f4 into main May 29, 2025
8 checks passed
@Mogball Mogball deleted the mogball/dataclass branch May 29, 2025 16:37
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants