[Backend] Use byte permutes in intra-warp layout conversion #7809
cc @lezcano @apgoucher
cc @lezcano Looks good performance-wise on the FlashAttention example kernel (and it helps fp16 not just fp8!):
I'm going to try the internal kernel benchmarks next
@FrederickVu It's breaking some AMD tests -- are you assuming anywhere that the warp size is 32?
@FrederickVu The problem may be that AMD's permute instruction uses the bytes rather than the nybbles of the selector mask to do the selection. Do we need to do something like: in order to convert from NVIDIA's convention to AMD's convention?
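To make the selector-convention question concrete, here is a minimal Python sketch. It assumes a simplified model in which both permutes pick each output byte from the same 8 source bytes and differ only in how the selector is packed: one 4-bit index per output byte versus one 8-bit index per output byte. The function names are hypothetical. Operand ordering and any special selector modes on real hardware are not modeled, so this is not the snippet proposed in the comment above.

```python
def permute_nibble_sel(lo: int, hi: int, sel: int) -> int:
    """Byte permute with one 4-bit index per output byte (low 16 bits of sel).

    Source bytes 0-3 come from `lo`, bytes 4-7 from `hi`. Only the low three
    bits of each index are modeled; special modes are ignored.
    """
    src = ((hi & 0xFFFFFFFF) << 32) | (lo & 0xFFFFFFFF)
    out = 0
    for i in range(4):
        idx = (sel >> (4 * i)) & 0x7
        out |= ((src >> (8 * idx)) & 0xFF) << (8 * i)
    return out


def permute_byte_sel(lo: int, hi: int, sel: int) -> int:
    """Same permute, but with one 8-bit index per output byte (32-bit sel)."""
    src = ((hi & 0xFFFFFFFF) << 32) | (lo & 0xFFFFFFFF)
    out = 0
    for i in range(4):
        idx = (sel >> (8 * i)) & 0x7
        out |= ((src >> (8 * idx)) & 0xFF) << (8 * i)
    return out


def nibble_sel_to_byte_sel(sel: int) -> int:
    """Spread the four selector nibbles into the four bytes of a 32-bit word."""
    out = 0
    for i in range(4):
        out |= ((sel >> (4 * i)) & 0xF) << (8 * i)
    return out
```

Under this model, `permute_byte_sel(lo, hi, nibble_sel_to_byte_sel(sel))` agrees with `permute_nibble_sel(lo, hi, sel)` for every 16-bit `sel`.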
Good catch. I missed that. I didn't make any assumptions about warp size.
@FrederickVu awesome -- the best place to put it is in the implementation of if such a transformation doesn't exist already. (This is way more preferable than adding extra complexity to the warp-shuffle code inside Triton.)
lezcano
left a comment
This is quite cool!
I'll defer to @apgoucher regarding the final implementation details, as he's already on top of these
…riton-lang#7809)" This reverts commit 58ae6f0.
@FrederickVu, sorry but I have to revert those 3 PRs: #7809, #7825, #7861. There is a functional regression caused by #7809, but the other two PRs have many dependencies on it, so I was not able to revert it cleanly on its own, and I couldn't manage a partial revert either. The following convert layout miscompiles with this PR:
```
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 8]}>
%4 = ttg.convert_layout %3 : tensor<16x8xf8E5M2, #mma> -> tensor<16x8xf8E5M2, #blocked2>
```
It can be reproduced on an Ampere, Hopper, or Blackwell GPU (I would expect any NVIDIA GPU to show the problem). I can try to get a reproducer I can share later, but it would be nice to make a unit test for this convert layout in Gluon anyway. Happy to land those back when the bug is fixed, or if you manage to partially revert only the NVIDIA permute part, that works too.
Reland triton-lang/triton#7809, triton-lang/triton#7825, triton-lang/triton#7861 Add a workaround for ptxas bug and add a regression test

Issue
We patch an oversight in #7558 where reindexing sub-32-bit elements before or after unpacking them from vectors can cause LLVM’s InstCombine to materialize `shufflevector`s in real kernels. These lower to byte permute instructions, which are not optimized away.
This was believed to cause a small regression in #7574. In the context of that PR, one has 8-bit elements packed into registers and a layout conversion described by the permutation `(r1 r2 l1 l0)` of register (`r*`) and lane (`l*`) basis vectors. Due to register packing, `r1` corresponds to an intra-register index bit.
The current algorithm factors this permutation into two parts: it implements `(r2 l1)(l0 l1)` using a `select-shuffle-select` pattern, and then applies `(r2 r1)` by reindexing the elements after extraction. As the elements are immediately repacked, InstCombine produces `shufflevector` instructions from the extract-insert pattern, resulting in one `prmt` per packed register.
Fix
It is possible to fuse the effects of these intra-register index bit permutations into the first and/or third stages of the `select-shuffle-select` pattern of the conversion algorithm. In most cases, this happens when, in the cycle decomposition of the layout conversion, the intra-register index bit is adjacent to a lane index bit within a cycle, as in the above example.
Edit after the merge:
Here's a comparison of the PTX and SASS from the LIT test of the manual and general implementations of the layout conversion in the aforementioned PR: enjoy.
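As a rough illustration of the extract-insert pattern described under Issue, the Python sketch below models two packed 32-bit registers holding eight 8-bit elements and applies the index-bit swap `(r2 r1)` by unpacking, reindexing, and repacking. It assumes element index bit 1 plays the role of `r1` (intra-register) and bit 2 the role of `r2` (selecting the packed register). The helper names and the selectors `0x5410` and `0x7632` are worked out for this toy model and are not taken from the PR. The `prmt` function is a simplified model that ignores special selector modes.

```python
def prmt(lo: int, hi: int, sel: int) -> int:
    """Simplified byte permute: nibble i of sel picks output byte i
    from the 8 source bytes (0-3 from `lo`, 4-7 from `hi`)."""
    src = ((hi & 0xFFFFFFFF) << 32) | (lo & 0xFFFFFFFF)
    out = 0
    for i in range(4):
        idx = (sel >> (4 * i)) & 0x7
        out |= ((src >> (8 * idx)) & 0xFF) << (8 * i)
    return out


def unpack(regs):
    """Split each 32-bit register into four bytes, lowest byte first."""
    return [(r >> (8 * k)) & 0xFF for r in regs for k in range(4)]


def pack(elems):
    """Inverse of unpack: group bytes four at a time into 32-bit registers."""
    return [sum(e << (8 * k) for k, e in enumerate(elems[i:i + 4]))
            for i in range(0, len(elems), 4)]


def swap_index_bits(i: int, p: int, q: int) -> int:
    """Swap bits p and q of the element index i."""
    if ((i >> p) & 1) != ((i >> q) & 1):
        i ^= (1 << p) | (1 << q)
    return i


def reindex(regs, p: int, q: int):
    """Extract all elements, permute them by swapping index bits, repack."""
    elems = unpack(regs)
    return pack([elems[swap_index_bits(i, p, q)] for i in range(len(elems))])
```

In this model, `reindex([r0, r1], 1, 2)` produces the same registers as `[prmt(r0, r1, 0x5410), prmt(r0, r1, 0x7632)]`, i.e. one byte permute per packed output register. That is the cost the fusion into the `select-shuffle-select` stages is meant to avoid.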
Future work
To the best of my knowledge, this PR handles all cases where the above fusion is possible. However, there are cases where it is not possible that still have potential for further optimization due to InstCombine’s lack of coverage:
Suppose we have four input `v4i8`s whose elements are rearranged via extraction and insertion into four output `v4i8`s in a manner such that each output vector contains one element from each of the four input vectors. In this case, LLVM generates 4 chains of 3 `prmt`s to build the output vectors, but it is possible to carry this out using two stages of 4 independent `prmt`s, thus reducing depth and instruction count.
This pattern can also exist in layout conversions that take the `transferWithinWarp` path, but as it is truly an intra-thread pattern, this optimization should be implemented in `transferWithinThread` and invoked within `transferWithinWarp` in a future PR.
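The two-stage scheme can be sketched in Python for one concrete instance of the pattern: a 4x4 byte transpose, where output `j` collects byte `j` of each input. This is a minimal model under stated assumptions, not the planned implementation. `prmt` is a simplified byte permute that ignores special selector modes, and the function names and selector constants are hypothetical, chosen for this example.

```python
def prmt(lo: int, hi: int, sel: int) -> int:
    """Simplified byte permute: nibble i of sel picks output byte i
    from the 8 source bytes (0-3 from `lo`, 4-7 from `hi`)."""
    src = ((hi & 0xFFFFFFFF) << 32) | (lo & 0xFFFFFFFF)
    out = 0
    for i in range(4):
        idx = (sel >> (4 * i)) & 0x7
        out |= ((src >> (8 * idx)) & 0xFF) << (8 * i)
    return out


def transpose_reference(a: int, b: int, c: int, d: int):
    """Element-by-element reference: output j = [a[j], b[j], c[j], d[j]]."""
    ins = (a, b, c, d)
    return [sum(((v >> (8 * j)) & 0xFF) << (8 * k) for k, v in enumerate(ins))
            for j in range(4)]


def transpose_two_stage(a: int, b: int, c: int, d: int):
    """Same result using two stages of four independent permutes (8 total)."""
    # Stage 1: interleave bytes of (a, b) and of (c, d).
    t0 = prmt(a, b, 0x5140)  # [a0, b0, a1, b1]
    t1 = prmt(a, b, 0x7362)  # [a2, b2, a3, b3]
    t2 = prmt(c, d, 0x5140)  # [c0, d0, c1, d1]
    t3 = prmt(c, d, 0x7362)  # [c2, d2, c3, d3]
    # Stage 2: combine 16-bit halves of the stage-1 results.
    return [
        prmt(t0, t2, 0x5410),  # [a0, b0, c0, d0]
        prmt(t0, t2, 0x7632),  # [a1, b1, c1, d1]
        prmt(t1, t3, 0x5410),  # [a2, b2, c2, d2]
        prmt(t1, t3, 0x7632),  # [a3, b3, c3, d3]
    ]
```

Each stage consists of four permutes with no dependencies among them, so the dependency depth is 2 with 8 permutes in total, compared with the depth-3, 12-permute chains described above.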