@codex-maintainers
Contributor

Description

This fixes a missing implicit downcast when storing blocks through TMA descriptors. Previously, attempting to widen the result of a descriptor load (e.g. from float16 to float32) and then store it back via the descriptor would result in an MLIR verification error because the block types no longer matched:

# ptr.element_ty is tl.float16
desc = tl._experimental_make_tensor_descriptor(ptr, shape=..., strides=..., block_shape=...)
value = desc.load([off_x, off_y]).to(tl.float32)
# 'tt.experimental_descriptor_store' op tensor desciptor block and tensor types must match
desc.store([off_x, off_y], value)

The pointer/tl.store path already cast values to the target element type; descriptor stores should behave the same.
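To make the before/after behavior concrete, here is a toy model in plain Python. It is not Triton code: `Block`, `Descriptor`, `verify_store`, `cast`, and both `descriptor_store_*` functions are hypothetical stand-ins invented for illustration, and the value conversion itself is omitted. It only sketches the ordering the description implies, assuming the cast happens before the verifier's type check.

```python
from dataclasses import dataclass, field


@dataclass
class Block:
    """Hypothetical stand-in for a block of values with an element type."""
    dtype: str
    values: list = field(default_factory=list)


@dataclass
class Descriptor:
    """Hypothetical stand-in for a tensor descriptor; only the element type matters here."""
    dtype: str


def verify_store(desc: Descriptor, block: Block) -> None:
    # Stand-in for the IR verifier: element types must agree.
    if desc.dtype != block.dtype:
        raise TypeError("block and tensor element types must match")


def cast(block: Block, dtype: str) -> Block:
    # Only the type tag changes in this toy model; numeric rounding is omitted.
    return Block(dtype, list(block.values))


def descriptor_store_strict(desc: Descriptor, block: Block) -> Block:
    # Old behavior in this model: no cast, so a widened block fails verification.
    verify_store(desc, block)
    return block


def descriptor_store_with_cast(desc: Descriptor, block: Block) -> Block:
    # New behavior in this model: cast to the descriptor's element type first.
    if block.dtype != desc.dtype:
        block = cast(block, desc.dtype)
    verify_store(desc, block)
    return block


desc = Descriptor("float16")
widened = Block("float32", [1.0, 2.0])

try:
    descriptor_store_strict(desc, widened)
    strict_failed = False
except TypeError:
    strict_failed = True

stored = descriptor_store_with_cast(desc, widened)
```

In this model `strict_failed` ends up `True`, while `stored.dtype` is `"float16"`, mirroring the verifier error and the fixed path described above.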

Changes

  • Updated descriptor_store in python/triton/language/semantic.py to cast the incoming tensor to the descriptor's element type before emitting the create_descriptor_store IR node.
  • Added a regression test test_tensor_descriptor_store_downcast to python/test/unit/cuda/test_experimental_tma.py which widens a float16/bfloat16 block to float32 and stores it back via the descriptor.
  • Ran pre-commit hooks to keep formatting consistent.

A quick check under TRITON_INTERPRET=1 shows the new downcast path works:
True  # torch.equal(a, out) when storing a widened float16 block
True  # bfloat16 as well
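The check above compares with `torch.equal` rather than a tolerance. A plausible reason this is safe: every float16 value is exactly representable in float32, so widening and then downcasting should return the original bits. The sketch below illustrates that property using only Python's standard `struct` module (format `e` is IEEE half precision); it does not run Triton, the sample values are arbitrary, and bfloat16 is not covered because `struct` has no format for it.

```python
import struct


def to_float16(x: float) -> float:
    """Round a Python float to the nearest IEEE float16 value."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def to_float32(x: float) -> float:
    """Round a Python float to the nearest IEEE float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def float16_bits(x: float) -> int:
    """Return the 16-bit pattern of x encoded as float16."""
    return struct.unpack("<H", struct.pack("<e", x))[0]


# Values as they would sit in a float16 buffer (arbitrary samples).
originals = [to_float16(v) for v in (0.1, 1.5, -3.14159, 1000.0, 6.1e-5)]

# Widen to float32 (models `.to(tl.float32)`), then downcast on store.
widened = [to_float32(v) for v in originals]
stored = [to_float16(v) for v in widened]

# Bitwise comparison, analogous to an exact-equality check on the output.
round_trip_exact = all(
    float16_bits(a) == float16_bits(b) for a, b in zip(originals, stored)
)
```

Here `round_trip_exact` is `True`. The downcast is only lossless because the values started as float16; a store of arbitrary float32 data would round, as it does with pointer stores.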

Checklist

  • I am not making a trivial change, such as fixing a typo in a comment.
  • I have written a PR description following these rules.
  • I have run pre-commit run --from-ref origin/main --to-ref HEAD.
  • I have added tests under python/test.
  • I have not added any lit tests.

This fix aligns descriptor stores with pointer store semantics and avoids an IR verifier failure when the stored block's element type is wider than the descriptor’s element type.

@Jokeren
Contributor

Jokeren commented Mar 18, 2025

@Mogball @gopoto Should the descriptor store be fully explicit, or can we allow an implicit cast?

@codex-maintainers codex-maintainers force-pushed the submission/6178/0304/1741838638 branch from 398351d to ca8f8cf Compare March 19, 2025 01:07
@ThomasRaoux
Collaborator

> @Mogball @gopoto Should the descriptor store be fully explicit, or can we allow an implicit cast?

I believe we want to allow the implicit cast; this matches the behavior of tl.store.

@ThomasRaoux ThomasRaoux force-pushed the submission/6178/0304/1741838638 branch from ca8f8cf to 20c80a7 Compare August 21, 2025 23:54
@ThomasRaoux ThomasRaoux force-pushed the submission/6178/0304/1741838638 branch from 20c80a7 to 4e11300 Compare August 21, 2025 23:55
@ThomasRaoux ThomasRaoux merged commit 286e91f into triton-lang:main Aug 22, 2025
9 checks passed