Add implicit downcast in TMA descriptor store #6236
Merged
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Description
This fixes a missing implicit downcast when storing blocks through TMA descriptors. Previously, attempting to widen the result of a descriptor load (e.g. from
float16tofloat32) and then store it back via the descriptor would result in an MLIR verification error because the block types no longer matched:The pointer/
tl.storepath already cast values to the target element type; descriptor stores should behave the same.Changes
descriptor_storeinpython/triton/language/semantic.pyto cast the incoming tensor to the descriptor's element type before emitting thecreate_descriptor_storeIR node.test_tensor_descriptor_store_downcasttopython/test/unit/cuda/test_experimental_tma.pywhich widens afloat16/bfloat16block tofloat32and stores it back via the descriptor.pre-commithooks to keep formatting consistent.A quick check under
TRITON_INTERPRET=1shows the new downcast path works:Checklist
pre-commit run --from-ref origin/main --to-ref HEAD.python/test.littests.This fix aligns descriptor stores with pointer store semantics and avoids an IR verifier failure when the stored block's element type is wider than the descriptor’s element type.