diff --git a/python/triton_kernels/tests/test_routing.py b/python/triton_kernels/tests/test_routing.py index 8f5d5f650011..b9dc7cb880a8 100644 --- a/python/triton_kernels/tests/test_routing.py +++ b/python/triton_kernels/tests/test_routing.py @@ -2,81 +2,64 @@ import torch from triton_kernels.routing import routing, routing_torch from triton_kernels.testing import assert_close -from triton_kernels.matmul_ogs_details.metadata import compute_metadata from triton_kernels.testing import assert_equal -def init_data(n_tokens, n_expts_tot, dtype=torch.float16, device="cuda"): +def init_data(n_tokens, n_expts_tot, dtype=torch.float32, device="cuda"): logits = torch.randn((n_tokens, n_expts_tot), dtype=dtype, device=device, requires_grad=True) return logits -def ref_expt_data(routing_data, n_gates, block_m): - hist = routing_data.expt_hist - n_expts_tot = routing_data.n_expts_tot - blks = (hist + block_m - 1) // block_m # matmul blocks needed - tsum = torch.cumsum(hist, dim=0) # prefix sum of tokens - bsum = torch.cumsum(blks, dim=0) # prefix sum of blocks - # Get the max number of matmul blocks of size d_tile needed (and is launched with). - # This assumes the worst distribution of all experts with one token except for one that has the rest. - if n_gates <= n_expts_tot: - grid_m = n_gates - else: - # ceil_div(n_gates - n_experts + 1, d_tile) + n_experts - 1 - # ceil_div(x, y): -(-x // y) - grid_m = n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // block_m) - bloc_data = -torch.ones(grid_m, dtype=torch.int32) - # compute data required to drive ragged batch matmul - for e in range(n_expts_tot): - offset = bsum[e - 1] if e else 0 - for b in range(blks[e]): - bloc_data[offset + b] = (b << 16) + e - - expt_data = torch.zeros(n_expts_tot * 3 + 2 + grid_m, dtype=torch.int32, device=hist.device) - expt_data[:n_expts_tot] = routing_data.expt_hist - expt_data[n_expts_tot + 1:n_expts_tot * 2 + 1] = tsum - expt_data[n_expts_tot * 2 + 2:n_expts_tot * 3 + 2] = bsum - expt_data[n_expts_tot * 3 + 2:] = bloc_data - return expt_data - - -@pytest.mark.parametrize("n_tokens", [371, 255, 256, 8192, 1023, 1024]) -@pytest.mark.parametrize("n_expts_tot, n_expts_act", [(128, 4), (1500, 8)]) -@pytest.mark.parametrize("block_m", [64, 128]) +n_tokens = [(x, None) for x in [371, 255, 256, 4096, 1023, 1024]] +n_tokens += [(1152, 911)] + + +@pytest.mark.parametrize("n_tokens_pad, n_tokens_raw", n_tokens) +@pytest.mark.parametrize("n_expts_tot, n_expts_act", [(128, 32), (1500, 8)]) @pytest.mark.parametrize("use_expt_indx", [False, True]) -@pytest.mark.parametrize("renormalize", [True, False]) -def test_op(n_tokens, n_expts_tot, n_expts_act, renormalize, block_m, use_expt_indx, device): +@pytest.mark.parametrize("sm_first", [True, False]) +def test_op(n_tokens_pad, n_tokens_raw, n_expts_tot, n_expts_act, sm_first, use_expt_indx, device): torch.manual_seed(2) - tri_logits = init_data(n_tokens, n_expts_tot, device=device).detach() - ref_logits = tri_logits.clone() + if n_tokens_raw is None: + n_tokens_raw = n_tokens_pad + n_routing_rows = None + else: + n_routing_rows = torch.tensor([n_tokens_raw], dtype=torch.int32, device=device) + n_gates_raw = n_tokens_raw * n_expts_act + tri_logits = init_data(n_tokens_pad, n_expts_tot, device=device).detach() + tri_logits[n_tokens_raw:, :] = float("inf") # should not be used + tri_logits = tri_logits.requires_grad_(True) + ref_logits = tri_logits.clone().detach().requires_grad_(True) + if use_expt_indx: rand_idx = lambda: torch.randperm(n_expts_tot, device="cuda", dtype=torch.int64) - 
tri_expt_indx = torch.stack([rand_idx()[:n_expts_act] for _ in range(n_tokens)]) + tri_expt_indx = torch.stack([rand_idx()[:n_expts_act] for _ in range(n_tokens_pad)]) tri_expt_indx, _ = torch.sort(tri_expt_indx, dim=1) - ref_expt_indx = tri_expt_indx[:n_tokens] + tri_expt_indx[n_tokens_raw:] = -99999 # should not be used + ref_expt_indx = tri_expt_indx[:n_tokens_raw] else: tri_expt_indx = ref_expt_indx = None - if not renormalize: - tri_logits = torch.softmax(tri_logits, dim=-1) - ref_logits = torch.softmax(ref_logits, dim=-1) - ref_routing_data, ref_gather, ref_scatter = routing_torch(ref_logits, n_expts_act, renormalize, ref_expt_indx) - tri_routing_data, tri_gather, tri_scatter = routing(tri_logits, n_expts_act, renormalize, tri_expt_indx) - ref_metadata = ref_expt_data(ref_routing_data, n_tokens * n_expts_act, block_m) - tri_metadata = compute_metadata(tri_routing_data, n_tokens * n_expts_act, block_m) + ref_routing_data, ref_gather, ref_scatter = routing_torch(ref_logits, n_expts_act, sm_first, ref_expt_indx, + n_rows=n_routing_rows) + tri_routing_data, tri_gather, tri_scatter = routing(tri_logits, n_expts_act, sm_first, tri_expt_indx, + n_rows=n_routing_rows) def _assert_indx_equal(ref, tri): assert_equal(ref, tri[:len(ref)]) assert torch.all(tri[len(ref):] == -1) - # print((ref_routing_data.expt_hist != tri_routing_data.expt_hist).nonzero()) - # breakpoint() - assert_close(ref_routing_data.gate_scal, tri_routing_data.gate_scal, 2e-2, 4e-3) + assert_close(ref_routing_data.gate_scal, tri_routing_data.gate_scal[:n_gates_raw], 2e-2, 4e-3) assert_equal(ref_routing_data.expt_hist, tri_routing_data.expt_hist) - assert_equal(ref_metadata[:n_expts_tot], tri_metadata.hist) - assert_equal(ref_metadata[n_expts_tot:2 * n_expts_tot + 1], tri_metadata.offs) - assert_equal(ref_metadata[3 * n_expts_tot + 1], tri_metadata.offs_sum) - assert_equal(ref_metadata[3 * n_expts_tot + 2:], tri_metadata.blocks) + ref_expt_data = ref_routing_data.expt_data + tri_expt_data = tri_routing_data.expt_data + assert_equal(ref_expt_data.hist, tri_expt_data.hist) + assert_equal(ref_expt_data.token_offs_raw, tri_expt_data.token_offs_raw) + assert len(ref_expt_data.token_offs_pad) == len(tri_expt_data.token_offs_pad) + assert len(ref_expt_data.block_pid_map) == len(tri_expt_data.block_pid_map) + for block_m in ref_expt_data.token_offs_pad.keys(): + assert_equal(ref_expt_data.token_offs_pad[block_m], tri_expt_data.token_offs_pad[block_m]) + assert_equal(ref_expt_data.block_pid_map[block_m], tri_expt_data.block_pid_map[block_m]) assert ref_routing_data.n_expts_tot == ref_routing_data.n_expts_tot assert ref_routing_data.n_expts_act == ref_routing_data.n_expts_act @@ -86,18 +69,22 @@ def _assert_indx_equal(ref, tri): _assert_indx_equal(ref_scatter.src_indx, tri_scatter.src_indx) _assert_indx_equal(ref_scatter.dst_indx, tri_scatter.dst_indx) + scales_grad = torch.randn_like(tri_routing_data.gate_scal) + ref_routing_data.gate_scal.backward(scales_grad[:n_gates_raw]) + tri_routing_data.gate_scal.backward(scales_grad) + + assert_close(ref_logits.grad[:n_tokens_raw], tri_logits.grad[:n_tokens_raw]) + def bench_routing(): import triton.profiler as proton n_tokens = 8192 - block_m = 128 n_expts_tot, n_expts_act = 128, 4 tri_logits = init_data(n_tokens, n_expts_tot) proton.start("routing") proton.activate() for i in range(100): tri_routing_data, tri_gather, tri_scatter = routing(tri_logits, n_expts_act) - tri_metadata = compute_metadata(tri_routing_data, n_tokens * n_expts_act, block_m) # noqa: F841 proton.finalize() try: import os 
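Note (illustrative, not part of the diff): the `ExptData` fields this test checks (`token_offs_pad`, `block_pid_map`) describe the ragged-matmul tile layout for each supported `block_m`. The sketch below mirrors `compute_expt_data_torch` later in this patch; the helper name `ref_tile_layout` is hypothetical and exists only for illustration.

```python
# For a per-expert token histogram `hist` and a tile size `block_m`:
# - token_offs_pad[e] is the exclusive prefix sum of per-expert tile counts
# - block_pid_map packs (tile_id << 16) | expert_id for every matmul tile;
#   in the real data structure the trailing, unused entries are set to -1.
def ref_tile_layout(hist, block_m):
    n_tiles = [(h + block_m - 1) // block_m for h in hist]
    token_offs_pad = [0]
    for t in n_tiles:
        token_offs_pad.append(token_offs_pad[-1] + t)
    block_pid_map = []
    for expert, t in enumerate(n_tiles):
        for tile in range(t):
            block_pid_map.append((tile << 16) + expert)
    return token_offs_pad, block_pid_map


# hist=[3, 0, 5], block_m=2 -> token_offs_pad=[0, 2, 2, 5] and
# block_pid_map=[(0<<16)+0, (1<<16)+0, (0<<16)+2, (1<<16)+2, (2<<16)+2]
assert ref_tile_layout([3, 0, 5], 2) == ([0, 2, 2, 5], [0, 1 << 16, 2, (1 << 16) + 2, (2 << 16) + 2])
```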
diff --git a/python/triton_kernels/triton_kernels/bitmatrix.py b/python/triton_kernels/triton_kernels/bitmatrix.py deleted file mode 100644 index d81fee8ba282..000000000000 --- a/python/triton_kernels/triton_kernels/bitmatrix.py +++ /dev/null @@ -1,33 +0,0 @@ -from dataclasses import dataclass - -import torch - -from .reduction_details.reduce_bitmatrix import clear_sums, sum_bitmatrix_rows - - -@dataclass -class Bitmatrix: - """ - Represents a boolean matrix in a packed format where each element occupies - a single bit of memory. - - We use a Bitmatrix to represent the routing information, where each row - corresponds to a token and each column corresponds to an expert. - - S is either None or an all-zero array of size >= n_cols; we pass it along - with the actual bitmatrix to avoid having to launch a separate memset - kernel when we call Bitmatrix::sum(). - """ - - data: torch.Tensor - shape: tuple[int] - S: torch.tensor - - def sum(self, partials_block_size): - n_rows, n_cols = self.shape - dev = self.data.device - if self.S is None: - self.S = clear_sums(n_cols, dev) - out_ret = self.S[:n_cols] - self.S = None # throw error if we try to sum again - return sum_bitmatrix_rows(self, out_ret, partials_block_size) diff --git a/python/triton_kernels/triton_kernels/compaction.py b/python/triton_kernels/triton_kernels/compaction.py index a8b1cf5d19d4..a1bb8da5274d 100644 --- a/python/triton_kernels/triton_kernels/compaction.py +++ b/python/triton_kernels/triton_kernels/compaction.py @@ -1,6 +1,6 @@ import torch from .compaction_details._masked_compaction import _masked_compaction -from .bitmatrix import Bitmatrix +from .datastruct import Bitmatrix def compaction(yv, yi, bitmask, sentinel=-1): @@ -33,7 +33,7 @@ def compaction(yv, yi, bitmask, sentinel=-1): ret_yv = torch.empty_like(yv) ret_yi = torch.empty_like(yi) if isinstance(bitmask, Bitmatrix): - bitmask = bitmask.data + bitmask = bitmask.handle _masked_compaction[(n_rows, )]( yv, yi, bitmask, bitmask.stride(0), bitmask.stride(1), # inputs diff --git a/python/triton_kernels/triton_kernels/datastruct.py b/python/triton_kernels/triton_kernels/datastruct.py new file mode 100644 index 000000000000..87b82c395002 --- /dev/null +++ b/python/triton_kernels/triton_kernels/datastruct.py @@ -0,0 +1,48 @@ +import torch +from .reduction_details.reduce_bitmatrix import clear_sums, sum_bitmatrix_rows + + +class Tensor: + + def __init__(self, handle, shape_raw, shape_pad=None): + self.handle = handle + self.ndim = handle.ndim + self.dtype = handle.dtype + self.device = handle.device + self.shape_pad = handle.shape if shape_pad is None else shape_pad + self.shape_raw = shape_raw + + def stride(self, *args): + return self.handle.stride(*args) + + def data_ptr(self): + return self.handle.data_ptr() + + +class Bitmatrix(Tensor): + """ + Represents a boolean matrix in a packed format where each element occupies + a single bit of memory. + + _scratchpad is either None or an all-zero array of size >= shape[-1]; we pass it along + with the actual bitmatrix to avoid having to launch a separate memset + kernel when we call Bitmatrix::sum(). 
+ """ + + _scratchpad: torch.Tensor + + def __init__(self, handle, shape_raw, scratchpad=None): + assert handle.ndim == 2 + shape_pad = [handle.shape[0], handle.shape[1] * 32] + super().__init__(handle, shape_raw, shape_pad) + assert self.dtype == torch.uint32 + self._scratchpad = scratchpad + + def sum(self, partials_block_size): + _, n_cols = self.shape_raw + dev = self.device + if self._scratchpad is None: + self._scratchpad = clear_sums(n_cols, dev) + out_ret = self._scratchpad[:n_cols] + self._scratchpad = None # throw error if we try to sum again + return sum_bitmatrix_rows(self, out_ret, partials_block_size, self.shape_raw[0]) diff --git a/python/triton_kernels/triton_kernels/matmul_ogs.py b/python/triton_kernels/triton_kernels/matmul_ogs.py index f22361866a49..c4bbd8ff7e1f 100644 --- a/python/triton_kernels/triton_kernels/matmul_ogs.py +++ b/python/triton_kernels/triton_kernels/matmul_ogs.py @@ -13,7 +13,6 @@ from .matmul_ogs_details._p_matmul_ogs import _p_matmul_ogs, get_per_device_per_stream_alloc_fn from .matmul_ogs_details._finalize_matmul import _finalize_matmul from .matmul_ogs_details.opt_flags import make_opt_flags -from .matmul_ogs_details.metadata import compute_metadata from .matmul_ogs_details.fast_contiguous import fast_contiguous from .numerics_details.mxfp import SwizzlingType from .specialize import specialize @@ -343,8 +342,7 @@ def apply_preprocessing_features(x, w, gather_indx, scatter_indx, routing_data, w = fast_contiguous(w.transpose(-1, -2)).transpose(-1, -2) # preprocess routing information and ptr lookup table M = x.shape[1] if gather_indx is None else gather_indx.src_indx.shape[0] - expt_data = compute_metadata(routing_data, M, opt_flags.block_m) - return x, w, preprocessing_features.swap_xw, writeback_idxs, writeback_size, finalize_scatter_idxs, expt_data + return x, w, preprocessing_features.swap_xw, writeback_idxs, writeback_size, finalize_scatter_idxs # --------------------- @@ -596,13 +594,9 @@ def matmul_ogs(x, w, bias, fused_activation, fused_postprocess_activation = fused_postprocess_activation, fused_activation # pre-processing - x, w, swap_xw, writeback_idxs, writeback_size, finalize_scatter_idxs, expt_data = apply_preprocessing_features( + x, w, swap_xw, writeback_idxs, writeback_size, finalize_scatter_idxs = apply_preprocessing_features( x, w, gather_indx, scatter_indx, routing_data, opt_flags, preprocessing_features ) - if expt_data.buffer is not None: - assert expt_data.hist.shape[0] == n_expts_tot, "invalid expt_data" - assert expt_data.offs.shape[0] == n_expts_tot + 1, "invalid expt_data" - assert expt_data.blocks.shape[0] == grid_m, "invalid expt_data" # matrix multiplication n_cta = batch_size * grid_m * grid_n * opt_flags.split_k n_cta = min(target_info.num_sms(), n_cta) if opt_flags.is_persistent else n_cta @@ -610,6 +604,12 @@ def matmul_ogs(x, w, bias, bias_stride = None if bias is None else bias.stride(0) num_indx = None if scatter_indx is None else scatter_indx.src_indx.shape[0] kernels = get_kernels(epilogue.specs, fused_activation.specs) + expt_data = routing_data.expt_data + block_m = opt_flags.block_m + expt_hist = None if expt_data is None else expt_data.hist + expt_hist_sum = None if expt_data is None else expt_data.token_offs_pad[block_m][-1] + expt_token_offs_raw = None if expt_data is None else expt_data.token_offs_raw + expt_block_pid_map = None if expt_data is None else expt_data.block_pid_map[block_m] (kernels._p_matmul_ogs if opt_flags.is_persistent else kernels._matmul_ogs)[(n_cta,)]( 
flex.out_data.reinterpret(memory["output"]), flex.out_data.reinterpret(out0), *out0.stride(), @@ -628,7 +628,7 @@ def matmul_ogs(x, w, bias, None if scatter_indx is None else scatter_indx.src_indx, num_indx, writeback_idxs, writeback_size, - expt_data.hist, expt_data.offs, expt_data.offs_sum, expt_data.blocks, + expt_hist, expt_token_offs_raw, expt_hist_sum, expt_block_pid_map, batch_size, grid_m, grid_n, out_alpha, *fused_activation.fn_args, fused_activation.reduction_n, @@ -659,7 +659,7 @@ def matmul_ogs(x, w, bias, NUM_SMS = n_cta if opt_flags.is_persistent else 0, **opt_flags.target_kernel_kwargs) # post-processing - out = apply_postprocessing_features(scatter_indx, finalize_scatter_idxs, opt_flags, expt_data.offs, + out = apply_postprocessing_features(scatter_indx, finalize_scatter_idxs, opt_flags, expt_token_offs_raw, num_indx, precision_config, routing_data, postprocessing_features, memory, fused_postprocess_activation, epilogue) diff --git a/python/triton_kernels/triton_kernels/matmul_ogs_details/metadata.py b/python/triton_kernels/triton_kernels/matmul_ogs_details/metadata.py deleted file mode 100644 index f951b4d76806..000000000000 --- a/python/triton_kernels/triton_kernels/matmul_ogs_details/metadata.py +++ /dev/null @@ -1,115 +0,0 @@ -from dataclasses import dataclass -import torch -import triton -import triton.language as tl - - -@dataclass -class ExptData: - hist: torch.Tensor - offs: torch.Tensor - offs_sum: torch.Tensor - blocks: torch.Tensor - buffer: torch.Tensor - - -@triton.jit -def _matmul_metadata_memset(Hist, n_expts_tot, MDTokStarts, MDTileStarts, MDTileInfo, BLOCK: tl.constexpr, - TILE_DIM: tl.constexpr, extra_block: tl.constexpr): - pid = tl.program_id(0) - - TileInfoOut = MDTileInfo + (pid - 1) * BLOCK + tl.arange(0, BLOCK) - - # if pid == 0 - initialize cumsums - if pid == 0: - x_tok = tl.zeros([BLOCK], dtype=MDTokStarts.dtype.element_ty) - x_tile = tl.zeros([BLOCK], dtype=MDTileStarts.dtype.element_ty) - - Tok_ptrs = MDTokStarts + tl.arange(0, BLOCK) - Tile_ptrs = MDTileStarts + tl.arange(0, BLOCK) - - for i in range(0, n_expts_tot, BLOCK): - offs_n = tl.arange(0, BLOCK) + i - if extra_block: - # we need an extra block at the end just to contain the final - # sum; this only happens if our total number of experts is an - # exact multiple of BLOCK, obviating the need for any masking - hist_tok = tl.load(Hist + offs_n) - else: - mask = offs_n < n_expts_tot - hist_tok = tl.load(Hist + offs_n, mask=mask, other=0) - hist_tile = tl.cdiv(hist_tok, TILE_DIM) - tok_starts = tl.cumsum(hist_tok, 0) + x_tok - x_tok += tl.sum(hist_tok, 0).to(MDTokStarts.dtype.element_ty) - tile_starts = tl.cumsum(hist_tile, 0) + x_tile - x_tile += tl.sum(hist_tile, 0).to(MDTileStarts.dtype.element_ty) - - tl.store(Tok_ptrs, tok_starts - hist_tok) - tl.store(Tile_ptrs, tile_starts - hist_tile) - - Tok_ptrs += BLOCK - Tile_ptrs += BLOCK - - if extra_block: - tl.store(Tok_ptrs, x_tok) - tl.store(Tile_ptrs, x_tile) - - else: - - tl.store(TileInfoOut, 0xffffffff) - - -@triton.jit -def _matmul_metadata_compute(Hist, MDTileStarts, MDTileInfo, BLOCK: tl.constexpr, TILE_DIM: tl.constexpr): - - expt_id = tl.program_id(0) - n_tokens = tl.load(Hist + expt_id) - n_blocks = tl.cdiv(n_tokens, TILE_DIM) - - tile_off = tl.load(MDTileStarts + expt_id) - MDTileInfo += tile_off - # MDTileInfo += tl.load(MDTilesStart + expt_id) - for block_off in range(0, n_blocks, BLOCK): - block_offs = block_off + tl.arange(0, BLOCK) - data = (block_offs << 16) + expt_id - tl.store(MDTileInfo + block_offs, data, 
mask=block_offs < n_blocks) - - -def compute_metadata(routing_data, n_rows, block_m): - if routing_data.expt_hist is None: - return ExptData(None, None, None, None, None) - MEMSET_BLOCK = 128 - HIST2_BLOCK_M = 512 - device = routing_data.expt_hist.device - n_expts_tot = routing_data.n_expts_tot - cdiv = triton.cdiv - if n_rows <= n_expts_tot: - grid_m = n_rows - else: - grid_m = n_expts_tot - 1 - ((n_expts_tot - n_rows - 1) // block_m) - - n_expts_pad = cdiv(n_expts_tot, MEMSET_BLOCK) * MEMSET_BLOCK - pad2 = cdiv(n_expts_tot + 1, MEMSET_BLOCK) * MEMSET_BLOCK - extra_block = (n_expts_pad != pad2) - pids = cdiv(grid_m, MEMSET_BLOCK) + 1 - - metadata_size = n_expts_pad + 2 * pad2 + MEMSET_BLOCK * (pids - 1) - - metadata = torch.empty(metadata_size, dtype=torch.int32, device=device) - - md_hist = routing_data.expt_hist[:n_expts_tot] - md_offs = metadata[:n_expts_tot + 1] - md_tile_starts = metadata[pad2:][:n_expts_tot + 1] - md_offs_sum = md_tile_starts[-1] - md_tile_infos = metadata[2 * pad2:][:grid_m] - _matmul_metadata_memset[(pids, )]( - routing_data.expt_hist, n_expts_tot, md_offs, md_tile_starts, md_tile_infos, - BLOCK=MEMSET_BLOCK, # optimization parameters - TILE_DIM=block_m, # constants - extra_block=extra_block, num_warps=1) - _matmul_metadata_compute[(n_expts_tot, )]( - routing_data.expt_hist, md_tile_starts, md_tile_infos, # outputs - BLOCK=HIST2_BLOCK_M, # optimization parameters - TILE_DIM=block_m, # constants - num_warps=4) - return ExptData(md_hist, md_offs, md_offs_sum, md_tile_infos, metadata) diff --git a/python/triton_kernels/triton_kernels/reduction_details/reduce_bitmatrix.py b/python/triton_kernels/triton_kernels/reduction_details/reduce_bitmatrix.py index 347917c95241..f9d0a984f29c 100644 --- a/python/triton_kernels/triton_kernels/reduction_details/reduce_bitmatrix.py +++ b/python/triton_kernels/triton_kernels/reduction_details/reduce_bitmatrix.py @@ -50,7 +50,7 @@ def _sum_bitmatrix_memset(Ret, BLOCK: tl.constexpr): @triton.jit -def _sum_bitmatrix_rows(B, shape_bm, stride_bm: tl.constexpr, stride_bn: tl.constexpr, # input bitmatrix +def _sum_bitmatrix_rows(B, shape_bm, NRowsRaw, stride_bm: tl.constexpr, stride_bn: tl.constexpr, # input bitmatrix Ret, Partials, stride_pm: tl.constexpr, stride_pn, shape_pn, # outputs BLOCK_MM: tl.constexpr, BLOCK_M: tl.constexpr): @@ -60,7 +60,10 @@ def _sum_bitmatrix_rows(B, shape_bm, stride_bm: tl.constexpr, stride_bn: tl.cons pid_n = tl.program_id(1) offs_m = pid_m * BLOCK_MM + tl.arange(0, BLOCK_MM) offs_n = pid_n * 32 + tl.arange(0, 32) - bits = tl.load(B + pid_n * stride_bn + offs_m * stride_bm, mask=offs_m < shape_bm, other=0) + n_rows = shape_bm + if NRowsRaw is not None: + n_rows = tl.load(NRowsRaw) + bits = tl.load(B + pid_n * stride_bn + offs_m * stride_bm, mask=offs_m < n_rows, other=0) bits = tl.reshape(bits, [TILE_SIZE, BLOCK_M]) ret = vpopc(bits) # [TILE_SIZE, 32] @@ -78,30 +81,30 @@ def clear_sums(n_cols, device, MEMSET_BLOCK=512): return out_ret -def sum_bitmatrix_rows(x, out_ret, partials_block_size=None): +def sum_bitmatrix_rows(x, out_ret, partials_block_size=None, n_rows_raw=None): assert partials_block_size is not None cdiv = triton.cdiv PARTIALS_BLOCK_M = partials_block_size - n_rows, n_cols = x.shape - assert out_ret.shape == (n_cols, ) + n_rows_pad, n_cols_raw = x.shape_pad[0], x.shape_raw[1] + assert out_ret.shape == (n_cols_raw, ) TILE_SIZE = 2 BLOCK_MM = PARTIALS_BLOCK_M * TILE_SIZE - pids_x = cdiv(n_rows, BLOCK_MM) - pids_y = cdiv(n_cols, 32) + pids_x = cdiv(n_rows_pad, BLOCK_MM) + pids_y = cdiv(n_cols_raw, 
32) out_partials = torch.empty((pids_y * 32, pids_x * TILE_SIZE), device=out_ret.device, dtype=torch.int32) out_partials = torch.transpose(out_partials, 0, 1) # output tensors _sum_bitmatrix_rows[(pids_x, pids_y)]( - x.data, x.data.shape[0], x.data.stride(0), x.data.stride(1), # input + x.handle, x.shape_pad[0], x.shape_raw[0], x.stride(0), x.stride(1), # input out_ret, # output [final reduction] out_partials, out_partials.stride(0), out_partials.stride(1), out_partials.shape[1], # output [partial reductions] BLOCK_M=PARTIALS_BLOCK_M, BLOCK_MM=BLOCK_MM, # constants num_warps=8) - out_partials = out_partials[:cdiv(n_rows, PARTIALS_BLOCK_M), :n_cols] + out_partials = out_partials[:cdiv(n_rows_pad, PARTIALS_BLOCK_M), :n_cols_raw] return out_ret, out_partials diff --git a/python/triton_kernels/triton_kernels/routing.py b/python/triton_kernels/triton_kernels/routing.py index c3610ad4a0df..f59b926bdda6 100644 --- a/python/triton_kernels/triton_kernels/routing.py +++ b/python/triton_kernels/triton_kernels/routing.py @@ -5,6 +5,8 @@ from .routing_details._routing_compute import _routing_compute_indx_offs from .routing_details._routing_compute import _routing_compute_indx from .routing_details._routing_compute import _routing_clear_bitmatrix +from .routing_details._expt_data import _expt_data_memset +from .routing_details._expt_data import _expt_data_compute @dataclass @@ -29,12 +31,47 @@ class ScatterIndx: dst_indx: torch.Tensor +@dataclass +class ExptData: + # hist[i] is the number of tokens routed to expert i + hist: torch.Tensor + # token_offs_raw[i] is the offset of the first token routed + # to expert i in an expert-sorted array + token_offs_raw: torch.Tensor + # token_offs_pad[block][i] is the offset of the first token routed + # to expert i in an expert-sorted array, assuming histogram + # rounded to the next multiple of `block` + token_offs_pad: dict[int, torch.Tensor] + # block_id_map[block] contain one value for each `pid`` launched by + # the matrix multiplication kernel launched with BLOCK_M=block: + # - the value is -1 if the `pid` has no work to do + # - otherwise, the value is two int16 (packed as an int32) that + # correspond respectively to (1) the expert assigned to + # the tokens processed by this pid; (2) the block assigned to the + # tokens processed by this pid (think `pid_m` in a regular matmul) + # see `test_routing.py` for a reference implementation and more details + block_pid_map: dict[int, torch.Tensor] + + def __post_init__(self): + if self.hist is not None: + assert self.hist.dtype == torch.int32 + if self.token_offs_raw is not None: + assert self.token_offs_raw.dtype == torch.int32 + if self.token_offs_pad is not None: + for v in self.token_offs_pad.values(): + assert v.dtype == torch.int32 + if self.block_pid_map is not None: + for v in self.block_pid_map.values(): + assert v.dtype == torch.int32 + + @dataclass class RoutingData: gate_scal: torch.Tensor = field() expt_hist: torch.Tensor = field() n_expts_tot: int = field() n_expts_act: int = field() + expt_data: ExptData = None # Used to make perf annotation cleaner: when we use expert sharding, we can # use this to tell the "expected" number of local tokens per expert, because @@ -49,80 +86,242 @@ def n_blocks(self, n_rows, block_m): # -------------------------- -# Triton routing +# sort tokens by expert # -------------------------- -def routing(logits, n_expts_act, renormalize=True, expt_indx=None, simulated_ep=1): - from .topk import topk - from .compaction import compaction - cdiv = triton.cdiv - HIST_BLOCK_M = 
64 - INDX_OFFS_BLOCK_M = 512 - MEMSET_BLOCK = 1024 - n_tokens, n_expts_tot = logits.shape - n_gates = n_tokens * n_expts_act - device = logits.device - expt_scal, expt_indx, bitmatrix = topk(logits, n_expts_act, apply_softmax=renormalize, y_indx=expt_indx) - # mutate bitmatrix - if simulated_ep > 1: +class SortTokens(torch.autograd.Function): + + @staticmethod + def forward(ctx, expt_scal, expt_indx, bitmatrix): + HIST_BLOCK_M = 64 + INDX_OFFS_BLOCK_M = 512 + MEMSET_BLOCK = 1024 + cdiv = triton.cdiv + + device = expt_scal.device + dtype = expt_scal.dtype + n_tokens_raw, n_expts_tot = bitmatrix.shape_raw + n_tokens_pad, n_expts_act = expt_scal.shape + n_gates_pad = n_tokens_pad * n_expts_act + + hist, partial_hist = bitmatrix.sum(partials_block_size=HIST_BLOCK_M) + assert hist.dtype == torch.int32 + # scratchpad + expt_offs = torch.empty(n_expts_tot, dtype=torch.int32, device=device) + combined_indx = torch.empty(n_gates_pad * 2, dtype=torch.int32, device=device) + # output + topk_indx = combined_indx[:n_gates_pad] + gate_indx = combined_indx[n_gates_pad:] + gate_scal = torch.empty(n_gates_pad, dtype=dtype, device=device) + _routing_memset_indx[(cdiv(n_gates_pad * 2, MEMSET_BLOCK) + 1, )]( + combined_indx, n_gates_pad * 2, -1, MEMSET_BLOCK, hist, # + expt_offs, hist.shape[0], BLOCK_N=512 # + ) + _routing_compute_indx_offs[(n_expts_tot, )]( + expt_offs, partial_hist, # inputs + partial_hist.shape[0], partial_hist.stride(0), partial_hist.stride(1), # outputs + BLOCK_M=INDX_OFFS_BLOCK_M, # tunable parameters + ) + indx_offs = partial_hist + _routing_compute_indx[(cdiv(n_tokens_pad, HIST_BLOCK_M), )]( + topk_indx, gate_indx, gate_scal, # outputs + expt_scal, expt_indx, indx_offs, indx_offs.stride(0), indx_offs.stride(1), # inputs + n_tokens_pad, n_tokens_raw, # input shape + BLOCK_M=HIST_BLOCK_M, # tunable parameters + N_EXPTS_ACT=n_expts_act, # constants + num_warps=1 if HIST_BLOCK_M * n_expts_act // 32 < 4 else 4 # + ) + ctx.n_tokens_raw = n_tokens_raw + ctx.n_tokens_pad = n_tokens_pad + ctx.n_expts_act = n_expts_act + ctx.save_for_backward(gate_indx) + return hist, topk_indx, gate_indx, gate_scal + + @staticmethod + def backward(ctx, _0, _1, _2, dgate_scal): + (gate_indx, ) = ctx.saved_tensors + dgate_scal = dgate_scal[gate_indx] + dgate_scal = dgate_scal.reshape(ctx.n_tokens_pad, ctx.n_expts_act) + return dgate_scal, None, None + + +def sort_tokens(expt_scal, expt_indx, bitmatrix): + return SortTokens.apply(expt_scal, expt_indx, bitmatrix) + + +# -------------------------- +# prune routing +# -------------------------- + + +class PruneRouting(torch.autograd.Function): + + @staticmethod + def forward(ctx, expt_scal, expt_indx, bitmatrix, simulated_ep): + from .compaction import compaction + n_tokens_pad = expt_scal.shape[0] + n_expts_tot = bitmatrix.shape_raw[-1] assert n_expts_tot % simulated_ep == 0 - _routing_clear_bitmatrix[(n_tokens, )]( - bitmatrix.data, - bitmatrix.data.stride(0), - bitmatrix.data.stride(1), - bitmatrix.data.shape[1], + _routing_clear_bitmatrix[(n_tokens_pad, )]( + bitmatrix.handle, + bitmatrix.handle.stride(0), + bitmatrix.handle.stride(1), + bitmatrix.handle.shape[1], n_expts_tot // simulated_ep, BLOCK_N=512, ) # perform compaction to update expt_scal / expt_indx expt_scal, expt_indx = compaction(expt_scal, expt_indx, bitmatrix) n_expts_tot = n_expts_tot // simulated_ep - bitmatrix.shape[-1] = n_expts_tot - # compute bitmatrix histogram - hist, partial_hist = bitmatrix.sum(partials_block_size=HIST_BLOCK_M) - # scratchpad - expt_offs = torch.empty(n_expts_tot, 
dtype=torch.int32, device=device) - combined_indx = torch.empty(n_gates * 2, dtype=torch.int32, device=device) - # output - topk_indx = combined_indx[:n_gates] - gate_indx = combined_indx[n_gates:] - gate_scal = torch.empty(n_gates, dtype=logits.dtype, device=device) - _routing_memset_indx[(cdiv(n_gates * 2, MEMSET_BLOCK) + 1, )](combined_indx, n_gates * 2, -1, MEMSET_BLOCK, hist, - expt_offs, hist.shape[0], BLOCK_N=512) - _routing_compute_indx_offs[(n_expts_tot, )]( - expt_offs, partial_hist, # inputs - partial_hist.shape[0], partial_hist.stride(0), partial_hist.stride(1), # outputs - BLOCK_M=INDX_OFFS_BLOCK_M, # tunable parameters - ) - indx_offs = partial_hist - _routing_compute_indx[(cdiv(n_tokens, HIST_BLOCK_M), )]( - topk_indx, gate_indx, gate_scal, # outputs - expt_scal, expt_indx, indx_offs, indx_offs.stride(0), indx_offs.stride(1), n_gates, # input - BLOCK_M=HIST_BLOCK_M, # tunable parameters - N_EXPTS_ACT=n_expts_act, # constants - num_warps=1 if HIST_BLOCK_M * n_expts_act // 32 < 4 else 4) + bitmatrix.shape_raw[-1] = n_expts_tot + return expt_scal, expt_indx, bitmatrix + + +def prune_routing(expt_scal, expt_indx, bitmatrix, simulated_ep): + return PruneRouting.apply(expt_scal, expt_indx, bitmatrix, simulated_ep) + + +# -------------------------- +# expt_data +# -------------------------- + + +def log2_power_of_two(x): + assert x > 0 and (x & (x - 1)) == 0, "x must be a power of two" + return x.bit_length() - 1 + + +def compute_expt_data(expt_hist, n_expts_tot, n_gates): + if expt_hist is None: + return ExptData(None, None, None, None) + MEMSET_BLOCK = 128 + HIST2_BLOCK_M = 512 + device = expt_hist.device + n_expts_tot = n_expts_tot + cdiv = triton.cdiv + # block_ms are all powers-of-two between 16 and 128 (inclusive) + block_m_log2_start = 4 + block_m_log2_end = 8 + block_m_num = block_m_log2_end - block_m_log2_start + if n_gates <= n_expts_tot: + max_n_tiles = n_gates + else: + max_n_tiles = n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // 2**block_m_log2_start) + # allocate memory + pad = lambda x: cdiv(x, MEMSET_BLOCK) * MEMSET_BLOCK + dtype = torch.int32 + token_offs_raw = torch.empty((n_expts_tot + 1, ), dtype=dtype, device=device) + token_offs_pad = torch.empty((block_m_num, pad(n_expts_tot + 1)), dtype=dtype, device=device) + block_pid_map = torch.empty((block_m_num, pad(max_n_tiles)), dtype=dtype, device=device) + # compute outputs + token_offs_pad = token_offs_pad[:, :n_expts_tot + 1] + block_pid_map = block_pid_map[:, :max_n_tiles] + memset_grid = cdiv(block_pid_map.shape[1], MEMSET_BLOCK) + 1 + _expt_data_memset[(memset_grid, block_m_num)]( + expt_hist, n_expts_tot, token_offs_raw, # + token_offs_pad, token_offs_pad.stride(0), # + block_pid_map, block_pid_map.stride(0), # + block_m_log2_start, BLOCK=MEMSET_BLOCK, # optimization parameters + num_warps=1) + _expt_data_compute[(n_expts_tot, block_m_num)]( + expt_hist, token_offs_pad, token_offs_pad.stride(0), block_pid_map, block_pid_map.stride(0), # outputs + block_m_log2_start, BLOCK=HIST2_BLOCK_M, # optimization parameters + num_warps=4) + # unpack into datastructure + token_offs_pad = {2**j: token_offs_pad[i, :] for i, j in enumerate(range(block_m_log2_start, block_m_log2_end))} + block_pid_map = {2**j: block_pid_map[i, :] for i, j in enumerate(range(block_m_log2_start, block_m_log2_end))} + return ExptData(expt_hist, token_offs_raw, token_offs_pad, block_pid_map) + + +# -------------------------- +# routing +# -------------------------- + + +def routing(logits, n_expts_act, sm_first=False, expt_indx=None, 
simulated_ep=1, n_rows=None): + from .topk import topk + if sm_first: + logits = torch.softmax(logits, dim=-1) + expt_scal, expt_indx, bitmatrix = topk(logits, n_expts_act, # + apply_softmax=not sm_first, y_indx=expt_indx, n_rows=n_rows) + # mutate bitmatrix + if simulated_ep > 1: + expt_scal, expt_indx, bitmatrix = prune_routing(expt_scal, expt_indx, bitmatrix, simulated_ep) + hist, topk_indx, gate_indx, gate_scal = sort_tokens(expt_scal, expt_indx, bitmatrix) # pack the matmul data structure + n_expts_tot = logits.shape[-1] // simulated_ep gather_indx = GatherIndx(src_indx=topk_indx, dst_indx=gate_indx) scatter_indx = ScatterIndx(src_indx=gate_indx, dst_indx=topk_indx) - return RoutingData(gate_scal, hist, n_expts_tot, n_expts_act), gather_indx, scatter_indx + expt_data = compute_expt_data(hist, n_expts_tot, topk_indx.numel()) + return RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data), gather_indx, scatter_indx + + +# -------------------------- +# torch reference +# -------------------------- + + +def compute_expt_data_torch(hist, n_expts_tot, n_gates): + # offset for each experts + device = hist.device + token_offs_raw = torch.cumsum(hist, dim=0) + token_offs_raw = torch.cat((torch.zeros(1, device=device), token_offs_raw)) + token_offs_raw = token_offs_raw.int() + # maximum number of tiles for all values of `block_m` considered + block_ms = [16, 32, 64, 128] + if n_gates <= n_expts_tot: + max_n_tiles = n_gates + else: + # ceil_div(n_gates - n_experts + 1, d_tile) + n_experts - 1 + # ceil_div(x, y): -(-x // y) + max_n_tiles = n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // min(block_ms)) + # fill up tile offset/infos for each block + token_offs_pad = dict() + block_pid_map = dict() + for block_m in [16, 32, 64, 128]: + n_tiles = (hist + block_m - 1) // block_m # matmul blocks needed + token_offs_pad[block_m] = torch.cumsum(n_tiles, dim=0) + token_offs_pad[block_m] = torch.cat((torch.zeros(1, device=device), token_offs_pad[block_m])) + token_offs_pad[block_m] = token_offs_pad[block_m].int() + # compute data required to drive ragged batch matmul + block_pid_map[block_m] = -torch.ones(max_n_tiles, device=device) + for e in range(n_expts_tot): + offset = token_offs_pad[block_m][e] + for b in range(n_tiles[e]): + block_pid_map[block_m][offset + b] = (b << 16) + e + block_pid_map[block_m] = block_pid_map[block_m].int() + return ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map) + +def routing_torch(logits, n_expts_act, sm_first=False, expt_indx=None, n_rows=None): + has_user_provided_indx = expt_indx is not None + n_gates_pad = logits.shape[0] * n_expts_act -def routing_torch(logits, n_expts_act, renormalize=True, expt_indx=None): + if n_rows is not None: + logits = logits[:n_rows, :] def topk(vals, k, expt_indx): # topk of experts - if expt_indx is None: - tk_idx = torch.argsort(-vals, dim=1, stable=True)[:, :k] + if has_user_provided_indx: + tk_indx = expt_indx else: - tk_idx = expt_indx - tk_val = torch.take_along_dim(vals, tk_idx, dim=1) - return tk_val, tk_idx + tk_indx = torch.argsort(-vals, dim=1, stable=True)[:, :k] + tk_indx = tk_indx.long() + tk_val = torch.take_along_dim(vals, tk_indx, dim=1) + tk_indx = tk_indx.int() + return tk_val, tk_indx _, n_expts_tot = logits.shape + if sm_first: + logits = torch.softmax(logits, dim=-1) expt_scal, expt_indx = topk(logits, n_expts_act, expt_indx) - if renormalize: + if not sm_first: expt_scal = torch.softmax(expt_scal, dim=-1) + # sort each token's selections by expert + if not has_user_provided_indx: + expt_indx, 
sort_indices = torch.sort(expt_indx, dim=1) + expt_scal = torch.gather(expt_scal, 1, sort_indices) # flatten topk data expt_scal = expt_scal.reshape(-1) expt_indx = expt_indx.reshape(-1).to(torch.int32) @@ -130,8 +329,10 @@ def topk(vals, k, expt_indx): topk_indx = torch.argsort(expt_indx, stable=True) gate_indx = torch.argsort(topk_indx, stable=True) gate_scal = expt_scal[topk_indx] - hist = torch.histc(expt_indx, bins=n_expts_tot, max=n_expts_tot - 1) # histogram of tokens over experts + hist = torch.histc(expt_indx, bins=n_expts_tot, max=n_expts_tot - 1).int() # histogram of tokens over experts # pack the matmul data structure gather_indx = GatherIndx(src_indx=topk_indx.int(), dst_indx=gate_indx.int()) scatter_indx = ScatterIndx(src_indx=gate_indx.int(), dst_indx=topk_indx.int()) - return RoutingData(gate_scal, hist, n_expts_tot, n_expts_act), gather_indx, scatter_indx + # compute expt_data + expt_data = compute_expt_data_torch(hist, n_expts_tot, n_gates_pad) + return RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data), gather_indx, scatter_indx diff --git a/python/triton_kernels/triton_kernels/routing_details/_expt_data.py b/python/triton_kernels/triton_kernels/routing_details/_expt_data.py new file mode 100644 index 000000000000..d8ab905dd98a --- /dev/null +++ b/python/triton_kernels/triton_kernels/routing_details/_expt_data.py @@ -0,0 +1,69 @@ +import triton +import triton.language as tl + + +@triton.jit +def _cdiv_pow2(n, log2_k): + return (n + ((1 << log2_k) - 1)) >> log2_k + + +@triton.jit +def _expt_data_memset(Hist, n_expts_tot, MDTokStarts, MDTileStarts, tile_starts_stridem, MDTileInfo, tile_infos_stridem, + first_tile_dim_log2, BLOCK: tl.constexpr): + pid_n = tl.program_id(0) + pid_m = tl.program_id(1) + + tile_dim_log2 = first_tile_dim_log2 + pid_m + # if pid == 0 - initialize cumsums + if pid_n == 0: + MDTileStarts += pid_m * tile_starts_stridem + + x_tok = tl.zeros([BLOCK], dtype=MDTokStarts.dtype.element_ty) + x_tile = tl.zeros([BLOCK], dtype=MDTileStarts.dtype.element_ty) + + Tok_ptrs = MDTokStarts + tl.arange(0, BLOCK) + Tile_ptrs = MDTileStarts + tl.arange(0, BLOCK) + + for i in range(0, n_expts_tot + 1, BLOCK): + offs_n = tl.arange(0, BLOCK) + i + mask_n0 = offs_n < n_expts_tot + mask_n1 = offs_n < n_expts_tot + 1 + hist_tok = tl.load(Hist + offs_n, mask=mask_n0, other=0) + hist_tile = _cdiv_pow2(hist_tok, tile_dim_log2) + tok_starts = tl.cumsum(hist_tok, 0) + x_tok + x_tok += tl.sum(hist_tok, 0).to(MDTokStarts.dtype.element_ty) + tile_starts = tl.cumsum(hist_tile, 0) + x_tile + x_tile += tl.sum(hist_tile, 0).to(MDTileStarts.dtype.element_ty) + + tl.store(Tok_ptrs, tok_starts - hist_tok, mask=mask_n1) + tl.store(Tile_ptrs, tile_starts - hist_tile, mask=mask_n1) + + Tok_ptrs += BLOCK + Tile_ptrs += BLOCK + + else: + MDTileInfo += pid_m * tile_infos_stridem + TileInfoOut = MDTileInfo + (pid_n - 1) * BLOCK + tl.arange(0, BLOCK) + tl.store(TileInfoOut, 0xffffffff) + + +@triton.jit +def _expt_data_compute(Hist, MDTileStarts, tile_starts_stridem, MDTileInfo, tile_info_stridem, first_tile_dim_log2, + BLOCK: tl.constexpr): + expt_id = tl.program_id(0) + buff_id = tl.program_id(1) + + MDTileStarts += buff_id * tile_starts_stridem + MDTileInfo += buff_id * tile_info_stridem + + n_tokens = tl.load(Hist + expt_id) + tile_dim_log2 = first_tile_dim_log2 + buff_id + n_blocks = _cdiv_pow2(n_tokens, tile_dim_log2) + + tile_off = tl.load(MDTileStarts + expt_id) + MDTileInfo += tile_off + # MDTileInfo += tl.load(MDTilesStart + expt_id) + for block_off in range(0, n_blocks, 
BLOCK): + block_offs = block_off + tl.arange(0, BLOCK) + data = (block_offs << 16) + expt_id + tl.store(MDTileInfo + block_offs, data, mask=block_offs < n_blocks) diff --git a/python/triton_kernels/triton_kernels/routing_details/_routing_compute.py b/python/triton_kernels/triton_kernels/routing_details/_routing_compute.py index 4d389dd6f253..f05d6d607614 100644 --- a/python/triton_kernels/triton_kernels/routing_details/_routing_compute.py +++ b/python/triton_kernels/triton_kernels/routing_details/_routing_compute.py @@ -48,9 +48,13 @@ def _keyed_add(x, y): @triton.jit def _routing_compute_indx(GatherIndx, ScatterIndx, GateScal, ExptScal, ExptIndx, PartialOffs, stride_pm, stride_pn, - n_gates, BLOCK_M: tl.constexpr, N_EXPTS_ACT: tl.constexpr): + n_tokens_pad, NTokensRaw, BLOCK_M: tl.constexpr, N_EXPTS_ACT: tl.constexpr): pid_m = tl.program_id(0) + n_tokens = n_tokens_pad + if NTokensRaw is not None: + n_tokens = tl.load(NTokensRaw) + n_gates = n_tokens * N_EXPTS_ACT tl.static_assert(N_EXPTS_ACT * BLOCK_M <= 32768) diff --git a/python/triton_kernels/triton_kernels/swiglu.py b/python/triton_kernels/triton_kernels/swiglu.py index aa067c93ac06..606a03500a39 100644 --- a/python/triton_kernels/triton_kernels/swiglu.py +++ b/python/triton_kernels/triton_kernels/swiglu.py @@ -4,7 +4,6 @@ import triton from .swiglu_details._swiglu import _swiglu, _swiglu_fn from triton_kernels import target_info -from .matmul_ogs_details.metadata import compute_metadata @dataclass(frozen=True) @@ -53,7 +52,7 @@ def forward(ctx, a, alpha, precision_config, routing_data): grid = (min(M_BLOCKS * N_BLOCKS, 4 * num_sms), ) n_tokens = None if routing_data is not None: - n_tokens = compute_metadata(routing_data, M, BLOCK_M).offs[routing_data.n_expts_tot] + n_tokens = routing_data.expt_data.token_offs_raw[routing_data.n_expts_tot] _swiglu[grid]( flex_ctx.out_data.reinterpret(out), flex_ctx.out_data.expected_scale, diff --git a/python/triton_kernels/triton_kernels/topk.py b/python/triton_kernels/triton_kernels/topk.py index 718c6cf8d6fa..9b98f339659c 100644 --- a/python/triton_kernels/triton_kernels/topk.py +++ b/python/triton_kernels/triton_kernels/topk.py @@ -1,43 +1,88 @@ import torch -from .topk_details._topk import _topk -from .bitmatrix import Bitmatrix +import triton +from triton_kernels.topk_details._topk_forward import _topk_forward +from triton_kernels.topk_details._topk_backward import _topk_backward +from triton_kernels.datastruct import Bitmatrix +from triton_kernels.datastruct import Tensor -def topk(x, k, apply_softmax=True, dim=1, return_bitmatrix=True, y_indx=None): +def topk_forward(x, k, apply_softmax=True, dim=1, return_bitmatrix=True, y_indx=None, n_rows=None): + if not isinstance(x, Tensor): + x = Tensor(x, [n_rows, None]) cdiv = lambda a, b: (a + b - 1) // b BLOCK_M = 32 BLOCK_N = 32 BLOCK_S = 128 assert x.ndim == 2 - assert x.shape[-1] < 32768 + assert x.shape_pad[-1] < 32768 assert dim == 1 assert return_bitmatrix - n_rows, n_cols = x.shape + n_rows_pad, n_cols = x.shape_pad + n_rows_raw = x.shape_raw[0] dev = x.device n_cols_pad = cdiv(n_cols, BLOCK_N) * BLOCK_N n_cols_words = n_cols_pad // 32 # scratchpad tensors # NOTE: these are not returned - y_vals = torch.empty((n_rows, k), dtype=x.dtype, device=dev) + y_vals = torch.empty((n_rows_pad, k), dtype=x.dtype, device=dev) if y_indx is not None: use_provided_indx = True else: - y_indx = torch.empty((n_rows, k), dtype=torch.int16, device=dev) + y_indx = torch.empty((n_rows_pad, k), dtype=torch.int16, device=dev) use_provided_indx = False # create 
bitmatrix in transposed memory layout: - bitmatrix = torch.empty((n_cols_words, cdiv(n_rows, 32) * 32), dtype=torch.uint32, device=dev) - bitmatrix = torch.transpose(bitmatrix, 0, 1)[:n_rows] + bitmatrix = torch.empty((n_cols_words, cdiv(n_rows_pad, 32) * 32), dtype=torch.uint32, device=dev) + bitmatrix = torch.transpose(bitmatrix, 0, 1)[:n_rows_pad] s_blocks = cdiv(n_cols, BLOCK_S) s_cols = s_blocks * BLOCK_S - S = torch.empty((s_cols, ), dtype=torch.int32, device=dev) - pids = max(cdiv(n_rows, BLOCK_M), s_blocks) - _topk[(pids, )]( + scratchpad = torch.empty((s_cols, ), dtype=torch.int32, device=dev) + pids = max(cdiv(n_rows_pad, BLOCK_M), s_blocks) + _topk_forward[(pids, )]( x, x.stride(0), # inputs - y_vals, y_indx, y_vals.stride(0), # output [topk] - use_provided_indx, bitmatrix, bitmatrix.stride(0), bitmatrix.stride(1), # output [bitmatrix] - n_rows, n_cols, # shapes - S, BLOCK_S, s_blocks, # thing to memset to zero + y_vals, y_indx, y_vals.stride(0), use_provided_indx, # output [topk] + bitmatrix, bitmatrix.stride(0), bitmatrix.stride(1), # output [bitmatrix] + n_rows_pad, n_rows_raw, n_cols, # shapes + scratchpad, BLOCK_S, s_blocks, # thing to memset to zero BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, # tunable parameter - N_EXPTS_PAD=n_cols_pad, N_EXPTS_ACT=k, # constants - APPLY_SOFTMAX=apply_softmax) - return y_vals, y_indx, Bitmatrix(bitmatrix, [n_rows, n_cols], S) + APPLY_SOFTMAX=apply_softmax, N_EXPTS_PAD=n_cols_pad, N_EXPTS_ACT=k, # constants + ) + return y_vals, y_indx, Bitmatrix(bitmatrix, [n_rows_raw, n_cols], scratchpad) + + +def topk_backward(x, y_indx, dy_vals, k, n_rows, apply_softmax): + assert dy_vals.shape[-1] == k + n_expts_pad = triton.next_power_of_2(x.shape[-1]) + dx = torch.empty_like(x) + _topk_backward[(dy_vals.shape[0], )]( + y_indx, y_indx.stride(0), dy_vals, dy_vals.stride(0), x, x.stride(0), # inputs + dx, # outputs + dx.stride(0), x.shape[0], n_rows, x.shape[-1], APPLY_SOFTMAX=apply_softmax, N_EXPTS_ACT=k, + N_EXPTS_PAD=n_expts_pad) + return dx + + +class TopK(torch.autograd.Function): + + @staticmethod + def forward(ctx, x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows): + y_vals, y_indx, bitmatrix = topk_forward(x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows) + ctx.save_for_backward(x, y_indx) + ctx.apply_softmax = apply_softmax + ctx.k = k + ctx.n_rows = n_rows + return y_vals, y_indx, bitmatrix + + @staticmethod + def backward(ctx, dy_vals, _0, _1): + x, y_indx = ctx.saved_tensors + dx = topk_backward(x, y_indx, dy_vals, ctx.k, ctx.n_rows, ctx.apply_softmax) + return dx, None, None, None, None, None, None + + +def topk(x, k, apply_softmax=True, dim=1, return_bitmatrix=True, y_indx=None, n_rows=None): + ret = TopK.apply(x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows) + return ret + + +# x = torch.randn((32, 32), dtype=torch.float16, device="cuda") +# print(topk(x, 4)) diff --git a/python/triton_kernels/triton_kernels/topk_details/__init__.py b/python/triton_kernels/triton_kernels/topk_details/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/triton_kernels/triton_kernels/topk_details/_topk_backward.py b/python/triton_kernels/triton_kernels/topk_details/_topk_backward.py new file mode 100644 index 000000000000..eebe48177154 --- /dev/null +++ b/python/triton_kernels/triton_kernels/topk_details/_topk_backward.py @@ -0,0 +1,51 @@ +import triton +import triton.language as tl + + +@triton.jit +def _topk_backward( + Yi, + stride_ym, # topk indices + DY, + stride_dym, # output gradient values + X, + 
stride_xm, # input values + DX, + stride_dxm, # input gradient values + n_rows, + NRows, + n_expts_tot, + APPLY_SOFTMAX: tl.constexpr, + N_EXPTS_ACT: tl.constexpr, + N_EXPTS_PAD: tl.constexpr, +): + pid_m = tl.program_id(0) + if NRows is not None: + n_rows = tl.load(NRows) + if pid_m >= n_rows: + return + Yi += pid_m * stride_ym + DY += pid_m * stride_dym + X += pid_m * stride_xm + DX += pid_m * stride_dxm + # -- + offs_xn = tl.arange(0, N_EXPTS_PAD) + offs_yn = tl.arange(0, N_EXPTS_ACT) + mask_xn = offs_xn < n_expts_tot + # recompute softmax + y_indx = tl.load(Yi + offs_yn) + x = tl.load(X + y_indx) + x = x.to(tl.float32) + y = tl.softmax(x) + # compute input-gradient + dy = tl.load(DY + offs_yn) + dy = dy.to(tl.float32) + s = tl.sum(y * dy, 0) + # write-back input gradient + tl.store(DX + offs_xn, 0, mask=mask_xn) + tl.debug_barrier() + if APPLY_SOFTMAX: + dx = y * (dy - s) + else: + dx = dy + tl.store(DX + y_indx, dx) diff --git a/python/triton_kernels/triton_kernels/topk_details/_topk.py b/python/triton_kernels/triton_kernels/topk_details/_topk_forward.py similarity index 61% rename from python/triton_kernels/triton_kernels/topk_details/_topk.py rename to python/triton_kernels/triton_kernels/topk_details/_topk_forward.py index a1852ec1980d..ca17926e6ed3 100644 --- a/python/triton_kernels/triton_kernels/topk_details/_topk.py +++ b/python/triton_kernels/triton_kernels/topk_details/_topk_forward.py @@ -3,21 +3,28 @@ @triton.jit -def fpval_to_key(x): +def get_topmask_and_fullmask(x): tl.static_assert(x.dtype.is_int_unsigned(), "floating-point value must be passed as bits") tm: tl.constexpr = 1 << (-1 + x.dtype.primitive_bitwidth) fm: tl.constexpr = (1 << x.dtype.primitive_bitwidth) - 1 + tm_arr = tl.full(x.shape, tm, dtype=x.dtype) + fm_arr = tl.full(x.shape, fm, dtype=x.dtype) + return tm_arr, fm_arr + + +@triton.jit +def fpval_to_key(x): + tm, fm = get_topmask_and_fullmask(x) return x ^ tl.where((x & tm) != 0, fm, tm) @triton.jit def key_to_fpval(x): - tl.static_assert(x.dtype.is_int_unsigned(), "floating-point value must be passed as bits") - tm: tl.constexpr = 1 << (-1 + x.dtype.primitive_bitwidth) - fm: tl.constexpr = (1 << x.dtype.primitive_bitwidth) - 1 + tm, fm = get_topmask_and_fullmask(x) return x ^ tl.where((x & tm) == 0, fm, tm) +# stable top-k tie-breaks to value with smaller index @triton.jit def indx_to_key(indx, N_EXPTS_PAD: tl.constexpr): return N_EXPTS_PAD - indx @@ -33,8 +40,14 @@ def streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, N_EXPTS_PAD: tl.co BLOCK_N: tl.constexpr): x_nbits: tl.constexpr = X.dtype.element_ty.primitive_bitwidth x_utype: tl.constexpr = tl.dtype(f"uint{x_nbits}") - x_ultype: tl.constexpr = tl.dtype(f"uint{2*x_nbits}") - x_dbtype: tl.constexpr = tl.dtype(f"fp{2*x_nbits}") + if x_nbits < 16: + # this ensures that we leave at least 16 bits for expert index + # even if the input dtype is smaller than 16 bits: + y_nbits: tl.constexpr = 32 + else: + y_nbits: tl.constexpr = x_nbits * 2 + x_ultype: tl.constexpr = tl.dtype(f"uint{y_nbits}") + x_dtype: tl.constexpr = X.dtype.element_ty # subtract 1 from loop iterations because we peel the first (masked) iteration: loop_iterations: tl.constexpr = N_EXPTS_PAD // BLOCK_N - 1 @@ -45,7 +58,7 @@ def streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, N_EXPTS_PAD: tl.co X_ptrs = X + offs_m[:, None] * stride_xm + offs_x_n[None, :] x = tl.load(X_ptrs, mask=(mask_m & mask_n), other=float("-inf")) x = fpval_to_key(x.to(x_utype, bitcast=True)) - x = (x.to(x_ultype) << x_nbits) | indx_to_key(offs_x_n, 
N_EXPTS_PAD)[None, :] + x = (x.to(x_ultype) << 16) | indx_to_key(offs_x_n, N_EXPTS_PAD)[None, :] acc = tl.topk(x, N_EXPTS_ACT, dim=1) # subsequent iterations: @@ -55,27 +68,37 @@ def streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, N_EXPTS_PAD: tl.co offs_x_n -= BLOCK_N x = tl.load(X_ptrs, mask=mask_m, other=float("-inf")) x = fpval_to_key(x.to(x_utype, bitcast=True)) - x = (x.to(x_ultype) << x_nbits) | indx_to_key(offs_x_n, N_EXPTS_PAD)[None, :] + x = (x.to(x_ultype) << 16) | indx_to_key(offs_x_n, N_EXPTS_PAD)[None, :] acc = tl.maximum(acc, tl.topk(x, N_EXPTS_ACT, dim=1)) + # rotate expert index into upper 16 bits: + # 0000vvvvvvvviiii --> iiii0000vvvvvvvv + acc = (acc << (y_nbits - 16)) | (acc >> 16) + # sort in ascending order of expert (descending order of key) acc = tl.sort(acc, dim=1, descending=True) - acc_values = key_to_fpval((acc >> x_nbits).to(x_utype)) - acc_indices = key_to_indx(acc & 0x0000FFFF, N_EXPTS_PAD) - acc = (acc_values.to(x_ultype) << x_nbits) | acc_indices - acc = acc.to(x_dbtype, bitcast=True) + # iiii0000vvvvvvvv --> 0000iiii: + y_indices_raw = (acc >> (y_nbits - 16)).to(tl.uint32) + y_indices = key_to_indx(y_indices_raw, N_EXPTS_PAD) + # iiii0000vvvvvvvv --> vvvvvvvv: + y_values_raw = acc.to(x_utype) + y_values = key_to_fpval(y_values_raw).to(x_dtype, bitcast=True) - return acc + return y_values, y_indices @triton.jit -def _topk(X, stride_xm, # inputs - Yv, Yi, stride_ym, # topk values/indices - USE_PROVIDED_INDX: tl.constexpr, Bits, stride_rm: tl.constexpr, stride_rn: tl.constexpr, n_rows, # bitmatrix - n_expts_tot, S, BLOCK_S: tl.constexpr, s_blocks, # thing to memset - BLOCK_M: tl.constexpr, N_EXPTS_PAD: tl.constexpr, N_EXPTS_ACT: tl.constexpr, BLOCK_N: tl.constexpr, - APPLY_SOFTMAX: tl.constexpr): +def _topk_forward(X, stride_xm, # inputs + Yv, Yi, stride_ym, # topk values/indices + USE_PROVIDED_INDX: tl.constexpr, Bits, stride_rm: tl.constexpr, stride_rn: tl.constexpr, # bitmatrix + n_rows_pad, NRowsRaw, n_expts_tot, # shape + S, BLOCK_S: tl.constexpr, s_blocks, # thing to memset + APPLY_SOFTMAX: tl.constexpr, # constant + BLOCK_M: tl.constexpr, N_EXPTS_PAD: tl.constexpr, N_EXPTS_ACT: tl.constexpr, BLOCK_N: tl.constexpr): pid = tl.program_id(0) + n_rows = n_rows_pad + if NRowsRaw is not None: + n_rows = tl.load(NRowsRaw) if pid < s_blocks: tl.store(S + BLOCK_S * pid + tl.arange(0, BLOCK_S), tl.zeros([BLOCK_S], tl.int32)) @@ -87,9 +110,6 @@ def _topk(X, stride_xm, # inputs tl.static_assert(BLOCK_N % 32 == 0) tl.static_assert(N_EXPTS_PAD % BLOCK_N == 0) x_dtype: tl.constexpr = X.dtype.element_ty - x_nbits: tl.constexpr = X.dtype.element_ty.primitive_bitwidth - x_utype: tl.constexpr = tl.dtype(f"uint{x_nbits}") - x_ultype: tl.constexpr = tl.dtype(f"uint{2*x_nbits}") # load logits offs_m = pid * BLOCK_M + tl.arange(0, BLOCK_M) @@ -101,11 +121,10 @@ def _topk(X, stride_xm, # inputs Xv_ptrs = X + offs_m[:, None] * stride_xm + y_indices y_values = tl.load(Xv_ptrs, mask=mask_m) else: - y = streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, N_EXPTS_PAD, N_EXPTS_ACT, BLOCK_N) - y = y.to(x_ultype, bitcast=True) - y_indices = y & 0x0000FFFF - y_values = (y >> x_nbits).to(x_utype).to(x_dtype, bitcast=True) + y_values, y_indices = streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, # + N_EXPTS_PAD, N_EXPTS_ACT, BLOCK_N) + # normalize selected values if APPLY_SOFTMAX: y_values = tl.softmax(y_values.to(tl.float32), dim=1, keep_dims=True).to(x_dtype)
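Note (illustrative, not part of the diff): `_topk_forward` sorts floats as unsigned integer keys so that `tl.topk`/`tl.sort` can operate on packed (value, index) pairs. `fpval_to_key` sets the sign bit of non-negative values and flips every bit of negative ones, so unsigned key order matches float order; the low 16 bits of each packed key hold `N_EXPTS_PAD - index`, which makes ties break toward the smaller expert index. A minimal float32 sketch of the value encoding, using plain Python ints rather than Triton:

```python
import struct

TOP = 1 << 31            # sign bit of a float32
FULL = (1 << 32) - 1     # all 32 bits set

def fpval_to_key(bits: int) -> int:
    # negative (sign bit set): flip all bits; non-negative: set the sign bit
    return bits ^ (FULL if bits & TOP else TOP)

def key_to_fpval(key: int) -> int:
    # exact inverse of fpval_to_key
    return key ^ (TOP if key & TOP else FULL)

def f2b(x: float) -> int:
    return struct.unpack("<I", struct.pack("<f", x))[0]

vals = [-2.0, -0.5, 0.0, 0.75, 3.0]          # already in ascending float order
keys = [fpval_to_key(f2b(v)) for v in vals]
assert keys == sorted(keys)                  # unsigned key order matches float order
assert all(key_to_fpval(fpval_to_key(f2b(v))) == f2b(v) for v in vals)  # round-trip
```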