
Conversation

@FrederickVu (Contributor) commented Aug 10, 2025:

Issue

We patch an oversight in #7558: reindexing sub-32-bit elements before or after unpacking them from vectors can cause LLVM's InstCombine to materialize shufflevector instructions in real kernels, and these lower to byte-permute instructions that are not optimized away.

This was believed to cause a small regression in #7574. In the context of that PR, one has 8-bit elements packed into registers and a layout conversion described by the permutation (r1 r2 l1 l0) of register (r*) and lane (l*) basis vectors. Due to register packing, r1 corresponds to an intra-register index bit.

The current algorithm interprets

(r1 r2 l1 l0) = (r2 r1) * (r2 l1)(l0 l1),

implements (r2 l1)(l0 l1) using a select-shuffle-select pattern, and then applies (r2 r1) by reindexing the elements after extraction. As the elements are immediately repacked, InstCombine produces shufflevector instructions from the extract-insert pattern, resulting in one prmt per packed register.
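
For concreteness, here is a minimal standalone sketch (C++, not Triton code) that checks this decomposition numerically by acting on the bits of a 4-bit index; the bit positions l0=0, l1=1, r1=2, r2=3 are an assumption made only for this check.

```cpp
// Minimal check of (r1 r2 l1 l0) = (r2 r1) * (r2 l1)(l0 l1), acting on index bits.
// A permutation of basis vectors moves the bit at position p to position sigma(p);
// a transposition is a plain bit swap. Products are applied right-to-left.
#include <cassert>
#include <cstdint>

// Swap bits i and j of an index.
static uint32_t swapBits(uint32_t idx, unsigned i, unsigned j) {
  uint32_t bi = (idx >> i) & 1u, bj = (idx >> j) & 1u;
  idx &= ~((1u << i) | (1u << j));
  return idx | (bi << j) | (bj << i);
}

int main() {
  constexpr unsigned L0 = 0, L1 = 1, R1 = 2, R2 = 3; // assumed bit positions
  for (uint32_t idx = 0; idx < 16; ++idx) {
    // Left-hand side: the 4-cycle r1->r2, r2->l1, l1->l0, l0->r1, written directly.
    uint32_t lhs = (((idx >> R1) & 1u) << R2) | (((idx >> R2) & 1u) << L1) |
                   (((idx >> L1) & 1u) << L0) | (((idx >> L0) & 1u) << R1);
    // Right-hand side: apply (l0 l1), then (r2 l1), then (r2 r1).
    uint32_t rhs = swapBits(swapBits(swapBits(idx, L0, L1), R2, L1), R2, R1);
    assert(lhs == rhs);
  }
  return 0;
}
```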

Fix

It is possible to fuse the effect of these intra-register index-bit permutations into the first and/or third stages of the select-shuffle-select pattern of the conversion algorithm. In most cases, this applies when, in the cycle decomposition of the layout conversion, the intra-register index bit is adjacent to a lane index bit within a cycle, as in the example above.
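
To illustrate only the core idea (not the actual implementation), here is a minimal standalone sketch showing that a byte reindexing applied after a byte-select can instead be folded into the select's mask, so no separate permute is emitted. `prmt()` is a software model of NVIDIA's prmt.b32 with its sign-replication mode ignored; the real first/third stages of the conversion are more involved than a single prmt.

```cpp
// Sketch: folding a post-hoc byte permutation into an existing byte-select mask.
#include <cassert>
#include <cstdint>

// Software model of prmt.b32: selector nibble i picks byte i of the result
// from the 8-byte pool {a = bytes 0-3, b = bytes 4-7} (sign-replication ignored).
static uint32_t prmt(uint32_t a, uint32_t b, uint32_t sel) {
  uint64_t pool = (static_cast<uint64_t>(b) << 32) | a;
  uint32_t out = 0;
  for (int i = 0; i < 4; ++i)
    out |= static_cast<uint32_t>((pool >> (8 * ((sel >> (4 * i)) & 0x7))) & 0xff)
           << (8 * i);
  return out;
}

int main() {
  uint32_t a = 0x33221100u, b = 0x77665544u, sel = 0x6240u;

  // Unfused: select bytes, then reindex them (here: swap bytes 1 and 2),
  // which costs an extra byte permute.
  uint32_t selected = prmt(a, b, sel);
  uint32_t unfused = prmt(selected, selected, 0x3120); // 0x3120 swaps bytes 1 and 2

  // Fused: permute the selector nibbles instead, so one prmt does both steps.
  const int perm[4] = {0, 2, 1, 3}; // output byte i takes selector nibble perm[i]
  uint32_t fusedSel = 0;
  for (int i = 0; i < 4; ++i)
    fusedSel |= ((sel >> (4 * perm[i])) & 0xfu) << (4 * i);
  assert(prmt(a, b, fusedSel) == unfused);
  return 0;
}
```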

Edit after the merge:
Here's a comparison of the PTX and SASS from the LIT test of the manual and general implementations of the layout conversion in the aforementioned PR: enjoy.

Future work

To the best of my knowledge, this PR handles all cases where the above fusion is possible. However, there are cases where fusion is not possible that still have potential for further optimization, since InstCombine does not cover them:

Suppose we have four input v4i8s whose elements are rearranged via extraction and insertion into four output v4i8s in a manner such that each output vector contains one element from each of the four input vectors. In this case, LLVM generates 4 chains of 3 prmts to build the output vectors, but it is possible to carry this out using two stages of 4 independent prmts, thus reducing depth and instruction count.
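
As a concrete (hypothetical) instance of this, here is a minimal standalone sketch of the two-stage pattern on a byte transpose of four packed 32-bit registers, again using a software model of NVIDIA's prmt.b32 (sign-replication mode ignored). The selectors and staging here illustrate the claim and are not code taken from Triton or LLVM.

```cpp
// Sketch: "each output takes one byte from each input" in two stages of 4 prmts.
#include <cassert>
#include <cstdint>

// Software model of prmt.b32: selector nibble i picks byte i of the result
// from the 8-byte pool {a = bytes 0-3, b = bytes 4-7}.
static uint32_t prmt(uint32_t a, uint32_t b, uint32_t sel) {
  uint64_t pool = (static_cast<uint64_t>(b) << 32) | a;
  uint32_t out = 0;
  for (int i = 0; i < 4; ++i)
    out |= static_cast<uint32_t>((pool >> (8 * ((sel >> (4 * i)) & 0x7))) & 0xff)
           << (8 * i);
  return out;
}

int main() {
  uint32_t in[4] = {0x03020100u, 0x13121110u, 0x23222120u, 0x33323130u};

  // Stage 1: interleave low/high byte pairs of (in0,in1) and (in2,in3).
  uint32_t t0 = prmt(in[0], in[1], 0x5140); // in0.b0 in1.b0 in0.b1 in1.b1
  uint32_t t1 = prmt(in[0], in[1], 0x7362); // in0.b2 in1.b2 in0.b3 in1.b3
  uint32_t t2 = prmt(in[2], in[3], 0x5140);
  uint32_t t3 = prmt(in[2], in[3], 0x7362);

  // Stage 2: combine halves so output k holds byte k of every input.
  uint32_t out[4] = {prmt(t0, t2, 0x5410), prmt(t0, t2, 0x7632),
                     prmt(t1, t3, 0x5410), prmt(t1, t3, 0x7632)};

  // Check against the direct definition of the byte transpose.
  for (int k = 0; k < 4; ++k)
    for (int j = 0; j < 4; ++j)
      assert(((out[k] >> (8 * j)) & 0xff) == ((in[j] >> (8 * k)) & 0xff));
  return 0;
}
```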

This pattern can also exist in layout conversions that take the transferWithinWarp path, but as it is truly an intra-thread pattern, this optimization should be implemented in transferWithinThread and invoked within transferWithinWarp in a future PR.

@FrederickVu (Contributor, Author) commented:

cc @lezcano @apgoucher
The decomposition algorithm is now admittedly unwieldy. Let me know if/how you'd like me to refactor, and feel free to edit it however you want in the meantime.

@apgoucher (Collaborator) commented:

cc @lezcano Looks good performance-wise on the FlashAttention example kernel (and it helps fp16, not just fp8!):

  • main:
fused-attention-batch4-head32-d64-fwd-causal=True-warp_specialize=False:
     N_CTX  Triton [FP16]  Triton [FP8]     Flash-2
0   1024.0     242.314649    285.814019  183.118205
1   2048.0     334.869910    394.962538  249.059750
2   4096.0     384.453054    459.962942  294.297099
3   8192.0     405.544529    494.332730  308.856540
4  16384.0     407.918684    513.261780  318.415601
fused-attention-batch4-head32-d64-fwd-causal=False-warp_specialize=False:
     N_CTX  Triton [FP16]  Triton [FP8]     Flash-2
0   1024.0     342.767076    407.895748  266.649268
1   2048.0     407.227630    470.817880  305.191436
2   4096.0     421.114108    501.936374  315.539498
3   8192.0     433.897669    517.116784  321.407416
4  16384.0     432.544829    523.620054  328.799751
  • your branch:
fused-attention-batch4-head32-d64-fwd-causal=True-warp_specialize=False:
     N_CTX  Triton [FP16]  Triton [FP8]     Flash-2
0   1024.0     242.816853    290.647775  183.833216
1   2048.0     334.322730    399.589327  247.832698
2   4096.0     395.088091    465.403247  294.341343
3   8192.0     395.751063    501.962617  310.237041
4  16384.0     427.227385    520.040330  321.363936
fused-attention-batch4-head32-d64-fwd-causal=False-warp_specialize=False:
     N_CTX  Triton [FP16]  Triton [FP8]     Flash-2
0   1024.0     342.953512    414.230484  269.494227
1   2048.0     406.708900    478.341362  305.030990
2   4096.0     433.202112    511.074539  317.313660
3   8192.0     429.079986    525.890151  321.108646
4  16384.0     446.499720    532.796369  328.831355

I'm going to try the internal kernel benchmarks next

@apgoucher (Collaborator) commented:

@FrederickVu It's breaking some AMD tests -- are you assuming anywhere that the warp size is 32?

@apgoucher (Collaborator) commented:

@FrederickVu The problem may be that AMD's permute instruction uses the bytes rather than the nybbles of the selector mask to do the selection. Do we need to do something like:

prmt_mask = (prmt_mask & 0x000000ff) | ((prmt_mask & 0x0000ff00) << 8); // 0x0000ffff --> 0x00ff00ff
prmt_mask = (prmt_mask & 0x000f000f) | ((prmt_mask & 0x00f000f0) << 4); // 0x00ff00ff --> 0x0f0f0f0f

in order to convert from NVIDIA's convention to AMD's convention?


@FrederickVu (Contributor, Author) commented:

> @FrederickVu The problem may be that AMD's permute instruction uses the bytes rather than the nybbles of the selector mask to do the selection. Do we need to do something like:
>
> prmt_mask = (prmt_mask & 0x000000ff) | ((prmt_mask & 0x0000ff00) << 8); // 0x0000ffff --> 0x00ff00ff
> prmt_mask = (prmt_mask & 0x000f000f) | ((prmt_mask & 0x00f000f0) << 4); // 0x00ff00ff --> 0x0f0f0f0f
>
> in order to convert from NVIDIA's convention to AMD's convention?

Good catch. I missed that. I didn't make any assumptions about warp size.

@apgoucher (Collaborator) commented:

> > @FrederickVu The problem may be that AMD's permute instruction uses the bytes rather than the nybbles of the selector mask to do the selection. Do we need to do something like:
> >
> > prmt_mask = (prmt_mask & 0x000000ff) | ((prmt_mask & 0x0000ff00) << 8); // 0x0000ffff --> 0x00ff00ff
> > prmt_mask = (prmt_mask & 0x000f000f) | ((prmt_mask & 0x00f000f0) << 4); // 0x00ff00ff --> 0x0f0f0f0f
> >
> > in order to convert from NVIDIA's convention to AMD's convention?
>
> Good catch. I missed that. I didn't make any assumptions about warp size.

@FrederickVu awesome -- the best place to put it is in the implementation of permute inside third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp; that way whenever someone uses permute inside Triton they don't need to worry about the target hardware. And then AMD can always write a compiler pass to commute pure functions through selects when both arguments are constants:

f(pred ? c0 : c1) --> (pred ? f(c0) : f(c1))

if such a transformation doesn't exist already. (This is far preferable to adding extra complexity to the warp-shuffle code inside Triton.)
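
For reference, here is a minimal standalone sketch of the nibble-to-byte selector expansion suggested above, wrapped in a hypothetical helper (the function name and where it would live are assumptions, and AMD's operand ordering for its permute instruction is not modeled here).

```cpp
// Sketch: expand each NVIDIA-style nibble selector into its own byte, since
// AMD's permute instruction reads one selector byte per output byte.
#include <cassert>
#include <cstdint>

static uint32_t expandNibbleSelectorsToBytes(uint32_t prmt_mask) { // hypothetical helper
  prmt_mask = (prmt_mask & 0x000000ff) | ((prmt_mask & 0x0000ff00) << 8); // 0x0000ffff --> 0x00ff00ff
  prmt_mask = (prmt_mask & 0x000f000f) | ((prmt_mask & 0x00f000f0) << 4); // 0x00ff00ff --> 0x0f0f0f0f
  return prmt_mask;
}

int main() {
  // Nibble selectors 5,4,1,0 become byte selectors 0x05,0x04,0x01,0x00.
  assert(expandNibbleSelectorsToBytes(0x5410u) == 0x05040100u);
  return 0;
}
```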

@lezcano (Contributor) left a comment:

This is quite cool!

I'll defer to @apgoucher regarding the final implementation details, as he's already on top of these

@lezcano merged commit 58ae6f0 into triton-lang:main Aug 11, 2025 (9 checks passed).
ThomasRaoux added a commit to ThomasRaoux/triton that referenced this pull request Aug 19, 2025
ThomasRaoux added a commit that referenced this pull request Aug 19, 2025
@FrederickVu, sorry but I have to revert those 3 PRs: #7809, #7825, #7861.

There is a functional regression caused by #7809, but the other two PRs have many dependencies on it, so I was not able to revert it cleanly on its own and couldn't manage a partial revert either.

The following convert_layout miscompiles with this PR:
```
#blocked2 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [4, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
#mma = #ttg.nvidia_mma<{versionMajor = 2, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 8]}>
%4 = ttg.convert_layout %3 : tensor<16x8xf8E5M2, #mma> -> tensor<16x8xf8E5M2, #blocked2>
```

It can be reproduced on Ampere, Hopper, or Blackwell GPUs (I would expect any NVIDIA GPU to show the problem).

I can try to get a shareable reproducer later, but it would be nice to have a unit test for this convert_layout in Gluon anyway.

Happy to land these back once the bug is fixed; alternatively, if you manage to revert only the NVIDIA permute part, that works too.
ThomasRaoux added a commit that referenced this pull request Aug 22, 2025
Reland #7809, #7825, #7861

Add a workaround for a ptxas bug and add a regression test.
januszjah pushed a commit to intel/intel-xpu-backend-for-triton that referenced this pull request Sep 11, 2025
januszjah pushed a commit to intel/intel-xpu-backend-for-triton that referenced this pull request Sep 15, 2025
januszjah pushed a commit to intel/intel-xpu-backend-for-triton that referenced this pull request Sep 16, 2025
januszjah pushed a commit to intel/intel-xpu-backend-for-triton that referenced this pull request Sep 16, 2025