Merged
Changes from all commits
33 commits
6696f55
hacky but works
ptillet May 27, 2025
3c6aee3
tinkering with datastructure
ptillet May 27, 2025
4ccd22b
.
ptillet May 28, 2025
b415a63
Merge remote-tracking branch 'origin/main' into phil/kernels/routing-…
ptillet May 28, 2025
7725eb3
cleanup
ptillet May 29, 2025
5b0e9e7
working bwd pass
ptillet May 30, 2025
1631935
fix bug
ptillet May 31, 2025
151e0fa
.
ptillet May 31, 2025
2cf849c
added missing file
ptillet May 31, 2025
666e9b4
clean-up ExptData
ptillet May 31, 2025
b6be9b4
.
ptillet May 31, 2025
528d6ba
.
ptillet May 31, 2025
2f0d281
Merge remote-tracking branch 'origin/main' into phil/kernels/routing-…
ptillet Jun 1, 2025
4eabc69
rename renormalize=True -> sm_first=False
ptillet Jun 1, 2025
5191300
cleanup
ptillet Jun 2, 2025
624bc57
some renaming
ptillet Jun 2, 2025
a135e36
more renaming
ptillet Jun 2, 2025
6e11f39
more comments
ptillet Jun 2, 2025
62182ce
more cleaning
ptillet Jun 2, 2025
f676bbb
finish moving expt_data to routing_data
ptillet Jun 2, 2025
741d9a8
small fix
ptillet Jun 2, 2025
9034511
more cleaning
ptillet Jun 2, 2025
f10eae8
.
ptillet Jun 2, 2025
c6d9396
sort gates in ascending expert index
ptillet Jun 3, 2025
326db75
.
ptillet Jun 3, 2025
96dac7d
.
ptillet Jun 3, 2025
09d9381
.
ptillet Jun 3, 2025
0118953
.
ptillet Jun 4, 2025
0ba302e
.
ptillet Jun 5, 2025
b9dffe0
.
ptillet Jun 5, 2025
1db970a
Incorporate changes from PR 7063
apgoucher Jun 5, 2025
f4f57d9
Merge branch 'main' into phil/kernels/routing-update
apgoucher Jun 5, 2025
9b5e3b6
fix typos
apgoucher Jun 5, 2025
99 changes: 43 additions & 56 deletions python/triton_kernels/tests/test_routing.py
@@ -2,81 +2,64 @@
import torch
from triton_kernels.routing import routing, routing_torch
from triton_kernels.testing import assert_close
from triton_kernels.matmul_ogs_details.metadata import compute_metadata
from triton_kernels.testing import assert_equal


def init_data(n_tokens, n_expts_tot, dtype=torch.float16, device="cuda"):
def init_data(n_tokens, n_expts_tot, dtype=torch.float32, device="cuda"):
logits = torch.randn((n_tokens, n_expts_tot), dtype=dtype, device=device, requires_grad=True)
return logits


def ref_expt_data(routing_data, n_gates, block_m):
hist = routing_data.expt_hist
n_expts_tot = routing_data.n_expts_tot
blks = (hist + block_m - 1) // block_m # matmul blocks needed
tsum = torch.cumsum(hist, dim=0) # prefix sum of tokens
bsum = torch.cumsum(blks, dim=0) # prefix sum of blocks
# Get the max number of matmul blocks of size d_tile needed (and is launched with).
# This assumes the worst distribution of all experts with one token except for one that has the rest.
if n_gates <= n_expts_tot:
grid_m = n_gates
else:
# ceil_div(n_gates - n_experts + 1, d_tile) + n_experts - 1
# ceil_div(x, y): -(-x // y)
grid_m = n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // block_m)
bloc_data = -torch.ones(grid_m, dtype=torch.int32)
# compute data required to drive ragged batch matmul
for e in range(n_expts_tot):
offset = bsum[e - 1] if e else 0
for b in range(blks[e]):
bloc_data[offset + b] = (b << 16) + e

expt_data = torch.zeros(n_expts_tot * 3 + 2 + grid_m, dtype=torch.int32, device=hist.device)
expt_data[:n_expts_tot] = routing_data.expt_hist
expt_data[n_expts_tot + 1:n_expts_tot * 2 + 1] = tsum
expt_data[n_expts_tot * 2 + 2:n_expts_tot * 3 + 2] = bsum
expt_data[n_expts_tot * 3 + 2:] = bloc_data
return expt_data


@pytest.mark.parametrize("n_tokens", [371, 255, 256, 8192, 1023, 1024])
@pytest.mark.parametrize("n_expts_tot, n_expts_act", [(128, 4), (1500, 8)])
@pytest.mark.parametrize("block_m", [64, 128])
n_tokens = [(x, None) for x in [371, 255, 256, 4096, 1023, 1024]]
n_tokens += [(1152, 911)]


@pytest.mark.parametrize("n_tokens_pad, n_tokens_raw", n_tokens)
@pytest.mark.parametrize("n_expts_tot, n_expts_act", [(128, 32), (1500, 8)])
@pytest.mark.parametrize("use_expt_indx", [False, True])
@pytest.mark.parametrize("renormalize", [True, False])
def test_op(n_tokens, n_expts_tot, n_expts_act, renormalize, block_m, use_expt_indx, device):
@pytest.mark.parametrize("sm_first", [True, False])
def test_op(n_tokens_pad, n_tokens_raw, n_expts_tot, n_expts_act, sm_first, use_expt_indx, device):
torch.manual_seed(2)
tri_logits = init_data(n_tokens, n_expts_tot, device=device).detach()
ref_logits = tri_logits.clone()
if n_tokens_raw is None:
n_tokens_raw = n_tokens_pad
n_routing_rows = None
else:
n_routing_rows = torch.tensor([n_tokens_raw], dtype=torch.int32, device=device)
n_gates_raw = n_tokens_raw * n_expts_act
tri_logits = init_data(n_tokens_pad, n_expts_tot, device=device).detach()
tri_logits[n_tokens_raw:, :] = float("inf") # should not be used
tri_logits = tri_logits.requires_grad_(True)
ref_logits = tri_logits.clone().detach().requires_grad_(True)

if use_expt_indx:
rand_idx = lambda: torch.randperm(n_expts_tot, device="cuda", dtype=torch.int64)
tri_expt_indx = torch.stack([rand_idx()[:n_expts_act] for _ in range(n_tokens)])
tri_expt_indx = torch.stack([rand_idx()[:n_expts_act] for _ in range(n_tokens_pad)])
tri_expt_indx, _ = torch.sort(tri_expt_indx, dim=1)
ref_expt_indx = tri_expt_indx[:n_tokens]
tri_expt_indx[n_tokens_raw:] = -99999 # should not be used
ref_expt_indx = tri_expt_indx[:n_tokens_raw]
else:
tri_expt_indx = ref_expt_indx = None
if not renormalize:
tri_logits = torch.softmax(tri_logits, dim=-1)
ref_logits = torch.softmax(ref_logits, dim=-1)
ref_routing_data, ref_gather, ref_scatter = routing_torch(ref_logits, n_expts_act, renormalize, ref_expt_indx)
tri_routing_data, tri_gather, tri_scatter = routing(tri_logits, n_expts_act, renormalize, tri_expt_indx)
ref_metadata = ref_expt_data(ref_routing_data, n_tokens * n_expts_act, block_m)
tri_metadata = compute_metadata(tri_routing_data, n_tokens * n_expts_act, block_m)
ref_routing_data, ref_gather, ref_scatter = routing_torch(ref_logits, n_expts_act, sm_first, ref_expt_indx,
n_rows=n_routing_rows)
tri_routing_data, tri_gather, tri_scatter = routing(tri_logits, n_expts_act, sm_first, tri_expt_indx,
n_rows=n_routing_rows)

def _assert_indx_equal(ref, tri):
assert_equal(ref, tri[:len(ref)])
assert torch.all(tri[len(ref):] == -1)

# print((ref_routing_data.expt_hist != tri_routing_data.expt_hist).nonzero())
# breakpoint()
assert_close(ref_routing_data.gate_scal, tri_routing_data.gate_scal, 2e-2, 4e-3)
assert_close(ref_routing_data.gate_scal, tri_routing_data.gate_scal[:n_gates_raw], 2e-2, 4e-3)
assert_equal(ref_routing_data.expt_hist, tri_routing_data.expt_hist)

assert_equal(ref_metadata[:n_expts_tot], tri_metadata.hist)
assert_equal(ref_metadata[n_expts_tot:2 * n_expts_tot + 1], tri_metadata.offs)
assert_equal(ref_metadata[3 * n_expts_tot + 1], tri_metadata.offs_sum)
assert_equal(ref_metadata[3 * n_expts_tot + 2:], tri_metadata.blocks)
ref_expt_data = ref_routing_data.expt_data
tri_expt_data = tri_routing_data.expt_data
assert_equal(ref_expt_data.hist, tri_expt_data.hist)
assert_equal(ref_expt_data.token_offs_raw, tri_expt_data.token_offs_raw)
assert len(ref_expt_data.token_offs_pad) == len(tri_expt_data.token_offs_pad)
assert len(ref_expt_data.block_pid_map) == len(tri_expt_data.block_pid_map)
for block_m in ref_expt_data.token_offs_pad.keys():
assert_equal(ref_expt_data.token_offs_pad[block_m], tri_expt_data.token_offs_pad[block_m])
assert_equal(ref_expt_data.block_pid_map[block_m], tri_expt_data.block_pid_map[block_m])

assert ref_routing_data.n_expts_tot == tri_routing_data.n_expts_tot
assert ref_routing_data.n_expts_act == tri_routing_data.n_expts_act
@@ -86,18 +69,22 @@ def _assert_indx_equal(ref, tri):
_assert_indx_equal(ref_scatter.src_indx, tri_scatter.src_indx)
_assert_indx_equal(ref_scatter.dst_indx, tri_scatter.dst_indx)

scales_grad = torch.randn_like(tri_routing_data.gate_scal)
ref_routing_data.gate_scal.backward(scales_grad[:n_gates_raw])
tri_routing_data.gate_scal.backward(scales_grad)

assert_close(ref_logits.grad[:n_tokens_raw], tri_logits.grad[:n_tokens_raw])


def bench_routing():
import triton.profiler as proton
n_tokens = 8192
block_m = 128
n_expts_tot, n_expts_act = 128, 4
tri_logits = init_data(n_tokens, n_expts_tot)
proton.start("routing")
proton.activate()
for i in range(100):
tri_routing_data, tri_gather, tri_scatter = routing(tri_logits, n_expts_act)
tri_metadata = compute_metadata(tri_routing_data, n_tokens * n_expts_act, block_m) # noqa: F841
proton.finalize()
try:
import os
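Note: the deleted ref_expt_data helper above is the clearest description of the per-expert metadata that the updated test now reads from routing_data.expt_data. As a reading aid, here is a minimal pure-PyTorch sketch of the same computation under the new field names; the mapping onto hist / token_offs_raw / token_offs_pad / block_pid_map is an assumption based on the assertions in test_op, and the real buffers may carry extra padding (e.g. a leading zero or a worst-case grid size).

import torch

def expt_metadata_reference(hist, block_m):
    # Sketch of the per-expert metadata compared in test_op (mirrors the deleted
    # ref_expt_data helper); the field-name correspondence is assumed, not authoritative.
    n_expts_tot = hist.shape[0]
    blks = (hist + block_m - 1) // block_m        # matmul blocks of size block_m per expert
    token_offs_raw = torch.cumsum(hist, dim=0)    # prefix sum of routed tokens
    token_offs_pad = torch.cumsum(blks, dim=0)    # prefix sum of padded blocks
    # one entry per launched block, packing (block index within expert) << 16 | expert index
    block_pid_map = -torch.ones(int(token_offs_pad[-1]), dtype=torch.int32)
    for e in range(n_expts_tot):
        offset = int(token_offs_pad[e - 1]) if e else 0
        for b in range(int(blks[e])):
            block_pid_map[offset + b] = (b << 16) + e
    return token_offs_raw, token_offs_pad, block_pid_map

def worst_case_grid_m(n_gates, n_expts_tot, block_m):
    # Upper bound on launched blocks used by the old helper: the worst split gives every
    # expert one token except a single expert that receives the rest, i.e.
    # ceil_div(n_gates - n_expts_tot + 1, block_m) + n_expts_tot - 1.
    if n_gates <= n_expts_tot:
        return n_gates
    return n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // block_m)

print(expt_metadata_reference(torch.tensor([5, 0, 3]), block_m=2))
print(worst_case_grid_m(n_gates=24, n_expts_tot=8, block_m=4))  # -> 12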
33 changes: 0 additions & 33 deletions python/triton_kernels/triton_kernels/bitmatrix.py

This file was deleted.

4 changes: 2 additions & 2 deletions python/triton_kernels/triton_kernels/compaction.py
@@ -1,6 +1,6 @@
import torch
from .compaction_details._masked_compaction import _masked_compaction
from .bitmatrix import Bitmatrix
from .datastruct import Bitmatrix


def compaction(yv, yi, bitmask, sentinel=-1):
@@ -33,7 +33,7 @@ def compaction(yv, yi, bitmask, sentinel=-1):
ret_yv = torch.empty_like(yv)
ret_yi = torch.empty_like(yi)
if isinstance(bitmask, Bitmatrix):
bitmask = bitmask.data
bitmask = bitmask.handle

_masked_compaction[(n_rows, )](
yv, yi, bitmask, bitmask.stride(0), bitmask.stride(1), # inputs
48 changes: 48 additions & 0 deletions python/triton_kernels/triton_kernels/datastruct.py
@@ -0,0 +1,48 @@
import torch
from .reduction_details.reduce_bitmatrix import clear_sums, sum_bitmatrix_rows


class Tensor:

def __init__(self, handle, shape_raw, shape_pad=None):
self.handle = handle
self.ndim = handle.ndim
self.dtype = handle.dtype
self.device = handle.device
self.shape_pad = handle.shape if shape_pad is None else shape_pad
self.shape_raw = shape_raw

def stride(self, *args):
return self.handle.stride(*args)

def data_ptr(self):
return self.handle.data_ptr()


class Bitmatrix(Tensor):
"""
Represents a boolean matrix in a packed format where each element occupies
a single bit of memory.

_scratchpad is either None or an all-zero array of size >= shape[-1]; we pass it along
with the actual bitmatrix to avoid having to launch a separate memset
kernel when we call Bitmatrix::sum().
"""

_scratchpad: torch.Tensor

def __init__(self, handle, shape_raw, scratchpad=None):
assert handle.ndim == 2
shape_pad = [handle.shape[0], handle.shape[1] * 32]
super().__init__(handle, shape_raw, shape_pad)
assert self.dtype == torch.uint32
self._scratchpad = scratchpad

def sum(self, partials_block_size):
_, n_cols = self.shape_raw
dev = self.device
if self._scratchpad is None:
self._scratchpad = clear_sums(n_cols, dev)
out_ret = self._scratchpad[:n_cols]
self._scratchpad = None # throw error if we try to sum again
return sum_bitmatrix_rows(self, out_ret, partials_block_size, self.shape_raw[0])
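Note: since the packed layout is only described in the docstring, here is a small pure-PyTorch sketch of the per-column counts that Bitmatrix.sum() is expected to produce. The bit ordering (bit j of word k maps to column k * 32 + j) is an assumption, and this is a reference computation, not the sum_bitmatrix_rows Triton kernel used above.

import torch

def bitmatrix_sum_reference(packed, n_rows, n_cols):
    # Reference for the per-column counts returned by Bitmatrix.sum() (a sketch, not
    # the Triton kernel). Assumes bit j of word k encodes column k * 32 + j.
    words = packed.view(torch.int32)[:n_rows]     # reinterpret the packed 32-bit words
    shifts = torch.arange(32, device=packed.device)
    bits = (words.unsqueeze(-1) >> shifts) & 1    # (n_rows, n_words, 32) of 0/1
    dense = bits.reshape(n_rows, -1)[:, :n_cols]  # drop padding columns
    return dense.sum(dim=0).to(torch.int32)       # e.g. number of tokens routed to each expert

# tiny demo on random packed words
packed = torch.randint(-2**31, 2**31 - 1, (16, 4), dtype=torch.int32)
print(bitmatrix_sum_reference(packed, n_rows=16, n_cols=100))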
20 changes: 10 additions & 10 deletions python/triton_kernels/triton_kernels/matmul_ogs.py
@@ -13,7 +13,6 @@
from .matmul_ogs_details._p_matmul_ogs import _p_matmul_ogs, get_per_device_per_stream_alloc_fn
from .matmul_ogs_details._finalize_matmul import _finalize_matmul
from .matmul_ogs_details.opt_flags import make_opt_flags
from .matmul_ogs_details.metadata import compute_metadata
from .matmul_ogs_details.fast_contiguous import fast_contiguous
from .numerics_details.mxfp import SwizzlingType
from .specialize import specialize
@@ -343,8 +342,7 @@ def apply_preprocessing_features(x, w, gather_indx, scatter_indx, routing_data,
w = fast_contiguous(w.transpose(-1, -2)).transpose(-1, -2)
# preprocess routing information and ptr lookup table
M = x.shape[1] if gather_indx is None else gather_indx.src_indx.shape[0]
expt_data = compute_metadata(routing_data, M, opt_flags.block_m)
return x, w, preprocessing_features.swap_xw, writeback_idxs, writeback_size, finalize_scatter_idxs, expt_data
return x, w, preprocessing_features.swap_xw, writeback_idxs, writeback_size, finalize_scatter_idxs


# ---------------------
@@ -596,20 +594,22 @@ def matmul_ogs(x, w, bias,

fused_activation, fused_postprocess_activation = fused_postprocess_activation, fused_activation
# pre-processing
x, w, swap_xw, writeback_idxs, writeback_size, finalize_scatter_idxs, expt_data = apply_preprocessing_features(
x, w, swap_xw, writeback_idxs, writeback_size, finalize_scatter_idxs = apply_preprocessing_features(
x, w, gather_indx, scatter_indx, routing_data, opt_flags, preprocessing_features
)
if expt_data.buffer is not None:
assert expt_data.hist.shape[0] == n_expts_tot, "invalid expt_data"
assert expt_data.offs.shape[0] == n_expts_tot + 1, "invalid expt_data"
assert expt_data.blocks.shape[0] == grid_m, "invalid expt_data"
# matrix multiplication
n_cta = batch_size * grid_m * grid_n * opt_flags.split_k
n_cta = min(target_info.num_sms(), n_cta) if opt_flags.is_persistent else n_cta
flex = precision_config.flex_ctx
bias_stride = None if bias is None else bias.stride(0)
num_indx = None if scatter_indx is None else scatter_indx.src_indx.shape[0]
kernels = get_kernels(epilogue.specs, fused_activation.specs)
expt_data = routing_data.expt_data
block_m = opt_flags.block_m
expt_hist = None if expt_data is None else expt_data.hist
expt_hist_sum = None if expt_data is None else expt_data.token_offs_pad[block_m][-1]
expt_token_offs_raw = None if expt_data is None else expt_data.token_offs_raw
expt_block_pid_map = None if expt_data is None else expt_data.block_pid_map[block_m]
(kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(n_cta,)](
flex.out_data.reinterpret(memory["output"]),
flex.out_data.reinterpret(out0), *out0.stride(),
@@ -628,7 +628,7 @@ def matmul_ogs(x, w, bias,
None if scatter_indx is None else scatter_indx.src_indx,
num_indx,
writeback_idxs, writeback_size,
expt_data.hist, expt_data.offs, expt_data.offs_sum, expt_data.blocks,
expt_hist, expt_token_offs_raw, expt_hist_sum, expt_block_pid_map,
batch_size, grid_m, grid_n,
out_alpha,
*fused_activation.fn_args, fused_activation.reduction_n,
@@ -659,7 +659,7 @@ def matmul_ogs(x, w, bias,
NUM_SMS = n_cta if opt_flags.is_persistent else 0,
**opt_flags.target_kernel_kwargs)
# post-processing
out = apply_postprocessing_features(scatter_indx, finalize_scatter_idxs, opt_flags, expt_data.offs,
out = apply_postprocessing_features(scatter_indx, finalize_scatter_idxs, opt_flags, expt_token_offs_raw,
num_indx, precision_config, routing_data,
postprocessing_features, memory, fused_postprocess_activation, epilogue)
