-
Notifications
You must be signed in to change notification settings - Fork 2.6k
Add support for masked histograms #6695
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
ThomasRaoux
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looks cool, few comments
| // mask out the values for which input mask is invalid | ||
| binMask = b.and_(binMask, inputMaskBit); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
it looks like every loop iteration will AND the same value?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
inputMaskBit is loop invariant, but binMask is updated within the loop.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
right but we are just masking out bits in the loop? Would it be equivalent to apply it once to mask outside the loop?
| // mask out the values for which input mask is invalid | ||
| binMask = b.and_(binMask, inputMaskBit); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
right but we are just masking out bits in the loop? Would it be equivalent to apply it once to mask outside the loop?
lib/Dialect/TritonGPU/IR/Ops.cpp
Outdated
| auto mask = op.getMask(); | ||
| if (mask) | ||
| return failure(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't have a great solution but not having this optimization is likely to cause poor performance whenever the histogram with mask is used.
One thing we could do is create a convert_layout for the mask instead, still not always ideal but I wonder if it is more likely to be better overall
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You are right about the other one. I fixed the code. Thanks for your suggestion.
I am not very familiar with how layout conversions are handled. The current code assumes that both operands are of the same layout so that layout of the mask matches the layout of src. I am not entirely sure how Triton picks the layout for these. My understanding is the current code removes any conversions because it can compute the histogram in any order. When the mask is applied, the operation can still happen in any layout as long as both of these operands match in layout.
I think the current code ensures this by keeping any conversions if the mask is present and this can be inefficient as you pointed out. If both operands are being converted, I think you are suggesting that we should only convert the mask. If so, is there a good way to match the layout of the src? Do we implement this manually or is there a trait to indicate this? Another case I am curious about is if the src and mask have the same original layout and then both have (the same) conversion applied to each. Is this something we need to handle or would it be handled automatically by other layout passes?
I guess I am looking for idiomatic ways to represent layout constraints in Triton. Secondarily, I would like to confirm what should be converted. I assume ideally only the mask should be converted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@ThomasRaoux I revisited this PR and as per your suggestion added code to convert layout of mask to match the layout of src. As mentioned in the comments, I think the best we can do is a single conversion to have the mask match the src. Please let me know if I should do something else.
ThomasRaoux
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, one more comment on the pattern with convert layout. After that it should be good to land.
lib/Dialect/TritonGPU/IR/Ops.cpp
Outdated
| // Retrieve ancestor of a value before all conversions | ||
| auto getAncestorBeforeConversions = [](auto value) { | ||
| auto numConversions = 0; | ||
| while (auto convert = value.template getDefiningOp<ConvertLayoutOp>()) { | ||
| numConversions++; | ||
| value = convert.getSrc(); | ||
| } | ||
| return std::make_pair(value, numConversions); | ||
| }; | ||
|
|
||
| auto [src, numSrcConversions] = getAncestorBeforeConversions(op.getSrc()); | ||
|
|
||
| // If there is no mask, replace the src directly | ||
| if (!op.getMask()) { | ||
| if (!numSrcConversions) | ||
| return failure(); | ||
|
|
||
| rewriter.replaceOpWithNewOp<triton::HistogramOp>( | ||
| op, op->getResult(0).getType(), src, op.getMask()); | ||
| return success(); | ||
| } | ||
|
|
||
| // When mask is present, we want a single conversion on mask to match src's | ||
| // layout. If there are more conversions, delete them. | ||
| auto [mask, numMaskConversions] = | ||
| getAncestorBeforeConversions(op.getMask()); | ||
|
|
||
| auto sharedType = getI1SameShape(src.getType()); | ||
| if (numSrcConversions || numMaskConversions > 1) { | ||
| rewriter.setInsertionPoint(op); | ||
| mask = rewriter.create<ConvertLayoutOp>(op.getLoc(), sharedType, mask); | ||
| rewriter.replaceOpWithNewOp<triton::HistogramOp>( | ||
| op, op->getResult(0).getType(), src, mask); | ||
| return success(); | ||
| } else { | ||
| return failure(); | ||
| rewriter.replaceOpWithNewOp<triton::HistogramOp>( | ||
| op, op->getResult(0).getType(), convert.getSrc()); | ||
| return mlir::success(); | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this looks overly complicated, can we just keep the original logic and if a mask exist create a convert layout for it?
Also could you add a simple lit test for this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I simplified the convert logic for the mask and added a lit test.
Currently, Triton handles tensors with dimensions that are powers of two. This limitation makes it difficult to create histograms for irregularly sized tensors. Other operations such as load and store work around this limitation using a mask parameter. This commit adds similar support for histograms. Fixes triton-lang#4825.
ThomasRaoux
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks for the great work
New contributor declaration
I am not making a trivial change, such as fixing a typo in a comment.
I have written a PR description following these
rules.
I have run
pre-commit run --from-ref origin/main --to-ref HEAD.Select one of the following.
/testforlittests/unittestfor C++ tests/python/testfor end-to-end testsFILL THIS IN.Select one of the following.
littests.littests I have added follow these best practices,including the "tests should be minimal" section. (Usually running Python code
and using the instructions it generates is not minimal.)