@jhapradip
Contributor

New contributor declaration

  • I am not making a trivial change, such as fixing a typo in a comment.

  • I have written a PR description following these
    rules.

  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.

  • Select one of the following.

    • I have added tests.
      • /test for lit tests
      • /unittest for C++ tests
      • /python/test for end-to-end tests
    • This PR does not need a test because FILL THIS IN.
  • Select one of the following.

    • I have not added any lit tests.
    • The lit tests I have added follow these best practices,
      including the "tests should be minimal" section. (Usually running Python code
      and using the instructions it generates is not minimal.)

@jhapradip requested a review from ptillet as a code owner on May 3, 2025, 04:39
Collaborator

@ThomasRaoux left a comment

looks cool, few comments

// mask out the values for which input mask is invalid
binMask = b.and_(binMask, inputMaskBit);
Collaborator

it looks like every loop iteration will AND the same value?

Contributor Author

inputMaskBit is loop invariant, but binMask is updated within the loop.

Collaborator

right but we are just masking out bits in the loop? Would it be equivalent to apply it once to mask outside the loop?
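The hoisting question can be checked on plain integers. The sketch below is not the PR's lowering code: `maskInsideLoop`, `maskHoisted`, and `perIterBits` are hypothetical names, and it assumes every loop update to `binMask` is itself a bitwise AND. Under that assumption, and provided the loop runs at least once, applying the invariant mask once gives the same result as applying it on every iteration.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the loop under review: each iteration narrows
// binMask with AND and also ANDs in the loop-invariant inputMaskBit.
uint32_t maskInsideLoop(uint32_t binMask,
                        const std::vector<uint32_t> &perIterBits,
                        uint32_t inputMaskBit) {
  for (uint32_t bits : perIterBits) {
    binMask &= bits;
    binMask &= inputMaskBit; // same value on every iteration
  }
  return binMask;
}

// Same loop with the invariant mask applied once, outside the loop.
// AND is associative, commutative, and idempotent, so the result matches
// maskInsideLoop whenever the loop body executes at least once.
uint32_t maskHoisted(uint32_t binMask,
                     const std::vector<uint32_t> &perIterBits,
                     uint32_t inputMaskBit) {
  for (uint32_t bits : perIterBits)
    binMask &= bits;
  return binMask & inputMaskBit;
}
```

If some update in the real loop is not a pure AND (for example an OR or a select that can set bits again), this identity no longer holds and the mask has to stay where it is.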

Comment on lines 121 to 125
auto mask = op.getMask();
if (mask)
  return failure();
Collaborator

I don't have a great solution, but not having this optimization is likely to cause poor performance whenever the histogram is used with a mask.
One thing we could do is create a convert_layout for the mask instead. That is still not always ideal, but I wonder if it is more likely to be better overall.

Contributor Author

You are right about the other one. I fixed the code. Thanks for your suggestion.

I am not very familiar with how layout conversions are handled. The current code assumes that both operands are of the same layout so that layout of the mask matches the layout of src. I am not entirely sure how Triton picks the layout for these. My understanding is the current code removes any conversions because it can compute the histogram in any order. When the mask is applied, the operation can still happen in any layout as long as both of these operands match in layout.

I think the current code ensures this by keeping any conversions if the mask is present and this can be inefficient as you pointed out. If both operands are being converted, I think you are suggesting that we should only convert the mask. If so, is there a good way to match the layout of the src? Do we implement this manually or is there a trait to indicate this? Another case I am curious about is if the src and mask have the same original layout and then both have (the same) conversion applied to each. Is this something we need to handle or would it be handled automatically by other layout passes?

I guess I am looking for idiomatic ways to represent layout constraints in Triton. Secondarily, I would like to confirm what should be converted. I assume ideally only the mask should be converted.
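The "any order, as long as both operands match" argument can be illustrated on the CPU. This is a toy model rather than Triton code: a layout change is approximated by a permutation of element positions, `countMasked` and `permute` are hypothetical helpers, and skipping values outside `[0, numBins)` is an assumption made for the sketch.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Count values whose mask entry is true. Values outside [0, numBins) are
// skipped; that policy is an assumption of this sketch.
std::vector<int> countMasked(const std::vector<int> &src,
                             const std::vector<bool> &mask, int numBins) {
  std::vector<int> bins(numBins, 0);
  for (std::size_t i = 0; i < src.size(); ++i)
    if (mask[i] && src[i] >= 0 && src[i] < numBins)
      ++bins[src[i]];
  return bins;
}

// Reorder elements with a permutation: a toy stand-in for a layout change.
template <typename T>
std::vector<T> permute(const std::vector<T> &v,
                       const std::vector<std::size_t> &perm) {
  std::vector<T> out(v.size());
  for (std::size_t i = 0; i < v.size(); ++i)
    out[i] = v[perm[i]];
  return out;
}
```

Permuting `src` and `mask` with the same permutation leaves the counts unchanged, while permuting only one of them pairs values with the wrong mask bits. That is why a single conversion that brings the mask to the layout of src is enough.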

Contributor Author

@ThomasRaoux I revisited this PR and as per your suggestion added code to convert layout of mask to match the layout of src. As mentioned in the comments, I think the best we can do is a single conversion to have the mask match the src. Please let me know if I should do something else.

Collaborator

@ThomasRaoux left a comment

Thanks, one more comment on the pattern with convert layout. After that it should be good to land.

Comment on lines 123 to 135
// Retrieve ancestor of a value before all conversions
auto getAncestorBeforeConversions = [](auto value) {
  auto numConversions = 0;
  while (auto convert = value.template getDefiningOp<ConvertLayoutOp>()) {
    numConversions++;
    value = convert.getSrc();
  }
  return std::make_pair(value, numConversions);
};

auto [src, numSrcConversions] = getAncestorBeforeConversions(op.getSrc());

// If there is no mask, replace the src directly
if (!op.getMask()) {
  if (!numSrcConversions)
    return failure();

  // op.getMask() is null on this path, so the new op is unmasked as well
  rewriter.replaceOpWithNewOp<triton::HistogramOp>(
      op, op->getResult(0).getType(), src, op.getMask());
  return success();
}

// When mask is present, we want a single conversion on mask to match src's
// layout. If there are more conversions, delete them.
auto [mask, numMaskConversions] =
    getAncestorBeforeConversions(op.getMask());

auto sharedType = getI1SameShape(src.getType());
if (numSrcConversions || numMaskConversions > 1) {
  rewriter.setInsertionPoint(op);
  mask = rewriter.create<ConvertLayoutOp>(op.getLoc(), sharedType, mask);
  rewriter.replaceOpWithNewOp<triton::HistogramOp>(
      op, op->getResult(0).getType(), src, mask);
  return success();
}
// Nothing to fold: src has no conversion and mask has at most one
return failure();
Collaborator

This looks overly complicated. Can we just keep the original logic and, if a mask exists, create a convert layout for it?
Also, could you add a simple lit test for this?

Contributor Author

I simplified the convert logic for the mask and added a lit test.

jhapradip added 3 commits June 3, 2025 12:04
Currently, Triton handles tensors with dimensions that are powers of
two. This limitation makes it difficult to create histograms for
irregularly sized tensors. Other operations such as load and store work
around this limitation using a mask parameter. This commit adds similar
support for histograms. Fixes triton-lang#4825.
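
The use case in the commit message can be sketched on the CPU. This is an illustrative reference, not the Triton implementation: `nextPowerOfTwo` and `histogramWithPadding` are hypothetical names, padding with zeros is an arbitrary choice, and ignoring values outside `[0, numBins)` is an assumption. The input is padded to a power-of-two length, and a mask marks the padded slots so they are not counted.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Smallest power of two that is >= n (returns 1 for n == 0).
std::size_t nextPowerOfTwo(std::size_t n) {
  std::size_t p = 1;
  while (p < n)
    p *= 2;
  return p;
}

// Pad the input to a power-of-two length, mask out the padding, and count.
// Out-of-range values are skipped; that policy is an assumption of the sketch.
std::vector<int> histogramWithPadding(const std::vector<int> &values,
                                      int numBins) {
  const std::size_t padded = nextPowerOfTwo(values.size());
  std::vector<int> src(padded, 0);       // padding slots hold 0
  std::vector<bool> mask(padded, false); // padding slots are masked out
  for (std::size_t i = 0; i < values.size(); ++i) {
    src[i] = values[i];
    mask[i] = true;
  }

  std::vector<int> bins(numBins, 0);
  for (std::size_t i = 0; i < padded; ++i)
    if (mask[i] && src[i] >= 0 && src[i] < numBins)
      ++bins[src[i]];
  return bins;
}
```

Without the mask, the three padding zeros added to a five-element input would be counted in bin 0, which is the distortion the mask parameter is meant to avoid.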
Collaborator

@ThomasRaoux left a comment

LGTM, thanks for the great work

@ThomasRaoux merged commit 2a10b48 into triton-lang:main on Jun 3, 2025
8 checks passed