[AMD] Use v_perm instruction for convert_layout acceleration #9014
This PR introduces an AMD-specific ttg->llvm pattern which uses v_perm instructions instead of combinations of shifts and logical operations.
Current limitations of this pattern:
- Applied only to 8-bit data types;
- The conversion is required to be bijective;
- No permutation across threads in the workgroup.
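The three limitations can be read as a pre-check that runs before the pattern fires. The following is only a sketch under stated assumptions: `ConversionInfo`, `isBijective` and `vpermPatternApplies` are hypothetical names invented for illustration and are not Triton types or functions. The register mapping is assumed to be given as a list of source slots, one per destination slot.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical summary of a layout conversion (illustration only, not a Triton type).
struct ConversionInfo {
  unsigned elemBitWidth;             // bit width of the element type
  bool crossesThreads;               // true if data moves between lanes or warps
  std::vector<std::size_t> srcOfDst; // srcOfDst[d] = source slot feeding destination slot d
};

// Returns true if every source slot is used exactly once,
// i.e. the slot mapping is a bijection (no broadcasting, no dropped values).
bool isBijective(const std::vector<std::size_t> &srcOfDst) {
  std::vector<bool> seen(srcOfDst.size(), false);
  for (std::size_t src : srcOfDst) {
    if (src >= srcOfDst.size() || seen[src])
      return false;
    seen[src] = true;
  }
  return true;
}

// Mirrors the three limitations listed in the PR description.
bool vpermPatternApplies(const ConversionInfo &info) {
  return info.elemBitWidth == 8 && !info.crossesThreads &&
         isBijective(info.srcOfDst);
}
```

A mapping such as `{0, 0, 1, 1}` (two destinations reading the same source) is rejected, which corresponds to the broadcasting case the pattern does not handle.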
lezcano
left a comment
@antiagainst happy to review this one if you guys want, but I'd need a bit of context on the semantics of the instruction.
if (srcTy.getElementType().getIntOrFloatBitWidth() != 8)
  return failure();
// TODO: broadcasting is not supported at the moment.
if (!conversion.isInvertible())
You probably support warp / cta broadcasting tho, right?
If it was present in the input layout, yes.
This pattern improves cases where no inter-warp/thread communication happens, only permutations between registers. In such cases "conversion" contains only the "register" input dim, and we do not care whether there is broadcasting in warps/lanes.
ah, right, I see. Then I would assume this would be better suited as an optimisation in LLVM, as we do tons of register packing and unpacking all across triton and we rely on LLVM / ptxas (in nvidia) to optimise it
Yes, ideally LLVM should be able to combine everything in an optimal way, but in practice it does not always do so. I've experimented and found that in some simple cases it succeeds, in other cases it combines permutations only partially, and sometimes LLVM falls back to a series of bit operations, which requires 3x more instructions than we expect with the "optimal" approach.
For more details about "optimal pattern", see message below.
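To make the "series of bit operations" concrete, here is a source-level sketch of gathering one byte from each of four packed 32-bit registers using only shifts, masks and ors. The function name `gatherColumnBitOps` is made up for illustration, and the number of machine instructions a compiler emits for it is not claimed here; it only shows the shape of the fallback.

```cpp
#include <array>
#include <cstdint>

// Builds one output register from byte `col` of each of four packed inputs,
// using only shift, mask and or operations (no byte-permute instruction).
std::uint32_t gatherColumnBitOps(const std::array<std::uint32_t, 4> &rows,
                                 unsigned col) {
  std::uint32_t out = 0;
  for (unsigned row = 0; row < 4; ++row) {
    // Extract byte `col` of the input register.
    const std::uint32_t byte = (rows[row] >> (8 * col)) & 0xFFu;
    // Insert it at byte position `row` of the output register.
    out |= byte << (8 * row);
  }
  return out;
}
```

Every output register touches all four inputs, so the work grows with the number of packed bytes rather than with the number of registers.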
Right, there are a lot of loosely coupled patterns and passes in LLVM, which can make it hard to always sustain such a lowering flow end-to-end. So having a dedicated pattern to make sure we always emit the optimal code sounds good to me.
@FrederickVu will also take a look at this given he had some thoughts in #7809.
My point here is that sure, this is one of the cases where this pattern could be used, and if it's really useful for some real cases, it makes sense to add it. Now, more generally, this whole "unpack, reorder, pack again" pattern we do in quite a few places in triton and we just expect LLVM to generate nice code for us. This is why I was hinting at potentially having a look at improving the instcombine heuristics.
Will review the PR tomorrow tho.
Yup; actively looking at improving LLVM alongside too. :)
Details about this change for reviewers

This pattern aims to optimally convert intra-thread […]

V_PERM instruction

AMD GPUs (both CDNA and RDNA) have instruction […]

Let's take a look at an example: let's take […] indexing with mask bytes […] we get […]

Fast way to shuffle data

v_perm is very efficient if we need to shuffle bytes in registers, for example for the in_thread_transpose optimization. Let's try to transpose a 4x4xfp8 tensor. Each register contains 4 packed elements, so in total we need only 4 registers: […]

We can do this with v_perm instructions in two steps: first we combine halves of the input registers into temporary values, and then we combine output values from pairs of these tmp values: […]

Each temporary value requires one v_perm instruction to construct, and then each output register requires one instruction to combine two tmp values, for a total of 8 v_perm instructions to transpose a 4x4 tensor.

If we do the same with bit operations, this requires 24 instructions. For every output register we need to extract bytes from each input register (3 instructions) and then combine them back (3 bitwise fused or+shift instructions). This difference becomes even worse if registers hold more values. We have a kernel which needs to shuffle 16x4xf8 tiles, and this takes significant time.
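The two-step scheme described above can be sketched with a small software model of a byte-permute instruction. This is a simplified illustration, not the PR's implementation: `vpermModel` and `transpose4x4ViaPerm` are hypothetical names, the model assumes that selector values 0..3 pick bytes of the second source and 4..7 pick bytes of the first source, and it ignores any special selector values the hardware may define. The selector constants are derived for this model only; check the ISA guide before relying on them for real hardware.

```cpp
#include <array>
#include <cstdint>

// Simplified software model of a byte permute (ASSUMPTION: selectors 0..3
// index bytes of src1, 4..7 index bytes of src0; other selector values are
// not modelled and are reduced modulo 8).
std::uint32_t vpermModel(std::uint32_t src0, std::uint32_t src1,
                         std::uint32_t sel) {
  const std::uint64_t pool = (static_cast<std::uint64_t>(src0) << 32) | src1;
  std::uint32_t result = 0;
  for (unsigned i = 0; i < 4; ++i) {
    const unsigned idx = (sel >> (8 * i)) & 0x7u; // selector for output byte i
    const std::uint32_t byte =
        static_cast<std::uint32_t>((pool >> (8 * idx)) & 0xFFu);
    result |= byte << (8 * i);
  }
  return result;
}

// Transposes a 4x4 byte matrix stored one row per 32-bit register,
// using 8 permutes: 4 for the temporaries, 4 for the outputs.
std::array<std::uint32_t, 4>
transpose4x4ViaPerm(const std::array<std::uint32_t, 4> &rows) {
  // Step 1: interleave halves of row pairs.
  // t0/t2 hold columns 0 and 1 of a row pair, t1/t3 hold columns 2 and 3.
  const std::uint32_t t0 = vpermModel(rows[1], rows[0], 0x05010400u);
  const std::uint32_t t1 = vpermModel(rows[1], rows[0], 0x07030602u);
  const std::uint32_t t2 = vpermModel(rows[3], rows[2], 0x05010400u);
  const std::uint32_t t3 = vpermModel(rows[3], rows[2], 0x07030602u);
  // Step 2: combine matching halves of two temporaries into one column each.
  return {vpermModel(t2, t0, 0x05040100u), vpermModel(t2, t0, 0x07060302u),
          vpermModel(t3, t1, 0x05040100u), vpermModel(t3, t1, 0x07060302u)};
}
```

In this model each output register depends on only two temporaries, which is where the count of 8 permutes comes from: 4 to build the temporaries and 4 to build the outputs.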
@lezcano This PR is ready for review, while we look for ways to do this in LLVM.
lezcano
left a comment
Sadly, I didn't get to review this one, and I'm starting holidays tomorrow for a few weeks, so feel free to review internally.
@lezcano
- Assertion in transferWithVPerm that regBytes | numValues: this would fail if the layout conversion only has two 8-bit elements per register. matchAndRewrite doesn't have a check to bail out in this case.
  - I could not construct an in-thread layout conversion that requires a permutation, so I just added a test to cover the simple case of 2 elements per thread.
- Small typos like "indeces" instead of "indices" and "mergable" instead of "mergeable".
  - Fixed typos, reworked the comment describing the 4-way algorithm.
- Redundant checks, like in processOneWayDependencies: needBytePermute should always evaluate to true, since otherwise kRegister would not be present in the output of minimalCvtLayout(srcTy, dstTy).
  - This check is redundant for now, because we do not permit swizzling in registers. I've added these checks for generalization in case we permit such layouts. Consider the following example: [[0, 1], [0, 2], [1, 1]]
- Broadcasting can be handled as in the general layout conversion pathways. https://github.com/triton-lang/triton/blob/0cd582fe4645a146bd7c140806ecaae334fd676b/lib/Conversion/TritonGPUToLLVM/ConvertLayoutOpToLLVM.cpp#L149
  - The approach from convert layout is not always applicable here, because broadcasting can happen inside one register. Applying the layout transformation will probably lead to generation of the bit operations this conversion pattern aims to avoid. I've tried to implement this generalization, but it makes the code even more complex and I did not see live examples, so I will leave this for the future if we need it.
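The regBytes | numValues concern (the number of values must be a multiple of the bytes per register) can be expressed as an explicit bail-out check instead of an assertion. This is a hedged sketch only: `fillsWholeRegisters` is a hypothetical helper, the 32-bit register width is an assumption, and it is not the check that was added in the PR.

```cpp
#include <cstddef>

// Hypothetical guard: returns true only if the per-thread values fill whole
// registers, so a pattern could return failure() instead of hitting an assert.
// ASSUMPTION: registers are 32 bits wide unless stated otherwise.
bool fillsWholeRegisters(std::size_t numValues, unsigned elemBitWidth,
                         unsigned regBitWidth = 32) {
  if (elemBitWidth == 0 || regBitWidth % elemBitWidth != 0)
    return false;
  const std::size_t valuesPerReg = regBitWidth / elemBitWidth;
  return numValues % valuesPerReg == 0;
}
```

With two 8-bit elements per thread the check fails (2 is not a multiple of 4), which is the case raised in the review.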
antiagainst
left a comment
Largely looks good to me; just some minor comments.
@FrederickVu is looking at doing the optimization in LLVM proper, but it will take some time for that to fully land and be incorporated. In the meantime we can have this to improve perf. @lezcano would you mind taking another look?
lezcano
left a comment
I didn't go carefully through the maths, but overall looks reasonable to me.
…riton-lang#9014) This PR introduces AMD specific ttg->llvm pattern which uses v_perm instructions instead of combinations of shifts and logical operations. Limitations of this pattern: - Applied only for 8 bit data types; - Conversion required to be bijective; - No permutation across threads in workgroup. --------- Co-authored-by: Alexander Efimov <efimov.alexander@gmail.com>