[Gluon] Implement attention kernels for d64 and d128 #7009
Conversation
```python
self.ready_bars = ready_bars
self.empty_bars = empty_bars
self.num_buffers = gl.constexpr(num_buffers)
self.num_consumers = gl.constexpr(num_consumers)
```
I wonder, would an annotated assignment work here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, because this is being executed as Python code :/
Would love ideas on how to improve the syntax here. There are a couple of other places where explicit wrapping of values with `gl.constexpr` is needed, even inside Gluon code.
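To make the constraint concrete, here is a minimal sketch. The class name `Channel` and the import path are illustrative assumptions; only `gl.constexpr` and the attribute names come from the diff above. Because the constructor runs as ordinary Python rather than being traced by the Gluon compiler, an annotated assignment is never recorded anywhere the frontend could read, so explicit wrapping is the only thing that marks the value as a compile-time constant:

```python
# Illustrative sketch, not the PR's actual class. Assumes the standard Gluon
# import path; `gl.constexpr` is the wrapper discussed in the comments above.
from triton.experimental.gluon import language as gl


class Channel:
    def __init__(self, ready_bars, empty_bars, num_buffers, num_consumers):
        self.ready_bars = ready_bars
        self.empty_bars = empty_bars
        # An annotated assignment (`self.num_buffers: gl.constexpr = num_buffers`)
        # would be a no-op here: attribute annotations in plain Python are not
        # stored anywhere, so the attribute would still hold a plain int.
        # Explicit wrapping is what lets @gluon.jit code treat it as constexpr.
        self.num_buffers = gl.constexpr(num_buffers)
        self.num_consumers = gl.constexpr(num_consumers)
```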
```python
@gluon.jit
def release(self):
    if isinstance(self.mem, gl.shared_memory_descriptor):
        self.mem._keep_alive()
```
Do we need to add `_keep_alive` for tensor memory descriptors as well?
There is no corresponding operation for it, but thinking about it, we might need one.
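A hedged sketch of how `release` might eventually cover both cases, building on the diff above. The tensor-memory branch is only a placeholder, since (per the reply) no corresponding keep-alive operation exists for tensor memory descriptors today; `_keep_alive` on shared memory descriptors comes from the diff itself.

```python
# Assumes `from triton.experimental import gluon` and
# `from triton.experimental.gluon import language as gl`, as in the kernel code.
@gluon.jit
def release(self):
    # Shared memory: presumably extends the buffer's lifetime to this point
    # (this branch is taken verbatim from the diff above).
    if isinstance(self.mem, gl.shared_memory_descriptor):
        self.mem._keep_alive()
    # Tensor memory: no keep-alive operation exists yet. If one were added,
    # this is presumably where release() would dispatch on the descriptor type,
    # e.g. `elif isinstance(self.mem, <tensor memory descriptor type>): ...`.
```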
```
fused-attention-batch4-head32-d64-fwd-causal=True-config=Config(warp_specialize=True, use_gluon=False):
N_CTX Triton [FP16] Triton [FP8]
0 1024.0 206.903414 226.941886
1 2048.0 397.637820 397.894749
2 4096.0 538.967275 537.617627
3 8192.0 633.806800 630.562116
4 16384.0 690.005849 685.848451
fused-attention-batch4-head32-d64-fwd-causal=True-config=Config(warp_specialize=False, use_gluon=True):
N_CTX Triton [FP16] Triton [FP8]
0 1024.0 158.854438 167.127676
1 2048.0 430.994165 445.445875
2 4096.0 574.152492 593.125123
3 8192.0 669.977292 692.156473
4 16384.0 725.290898 749.413157
```
```
fused-attention-batch4-head32-d128-fwd-causal=True-config=Config(warp_specialize=True, use_gluon=False):
N_CTX Triton [FP16] Triton [FP8]
0 1024.0 168.101257 180.147833
1 2048.0 232.412164 251.379866
2 4096.0 282.099851 306.706495
3 8192.0 313.280068 342.515038
4 16384.0 331.734557 362.578482
fused-attention-batch4-head32-d128-fwd-causal=True-config=Config(warp_specialize=False, use_gluon=True):
N_CTX Triton [FP16] Triton [FP8]
0 1024.0 153.825955 303.351930
1 2048.0 504.991716 532.393691
2 4096.0 723.003324 750.103002
3 8192.0 888.283770 924.001097
4 16384.0 979.523873 1045.632286
```
peterbell10 left a comment:
🚀
Final perf numbers for this. The gap for D64 is quite big now, since cuBLAS appears to have improved perf a lot (from 720 to 920). For D128, most of the remaining gap, especially for small contexts, comes down to making the kernel persistent.