diff --git a/README.md b/README.md index c891ad4..6e8565d 100644 --- a/README.md +++ b/README.md @@ -2,12 +2,14 @@ [![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE) [![Python](https://img.shields.io/badge/python-≥3.10-blue.svg)](pyproject.toml) -[![GPU](https://img.shields.io/badge/NVIDIA-SM100-76b900.svg)](#requirements) +[![GPU](https://img.shields.io/badge/NVIDIA-SM100%20%2F%20SM103-76b900.svg)](#requirements) [![Stack: CuTe-DSL + Cuda](https://img.shields.io/badge/stack-CuTe--DSL%20%2B%20Cuda-purple.svg)](#stacks) -**MSA** (`fmha_sm100`) ships dense FlashAttention and sparse top-k attention -kernels for **NVIDIA SM100**. Two JIT-compiled stacks -share one Python package: +**MSA** (`fmha_sm100`, plus the companion `fmha_sm12x` namespace) ships +dense FlashAttention and sparse top-k attention kernels for **NVIDIA SM100 / +SM103**. Those kernels use SM100-only tcgen05/TMEM instructions, so SM120 / +SM121 (GB10) is served by the separate `fmha_sm12x` package rather than +aliasing the SM100 path: ![MSA architecture](docs/architecture.png) @@ -25,7 +27,7 @@ share one Python package: ## Requirements -- **GPU**: NVIDIA SM100. +- **GPU**: NVIDIA SM100 / SM103 for the `fmha_sm100` kernels; SM120 / SM121 (GB10) is served by the `fmha_sm12x` package (portable CUDA helpers + a Triton block-sparse prefill kernel, with Torch reference fallbacks). - **Toolchain**: CUDA Toolkit with `nvcc` on `PATH` (or `CUDA_HOME` / `CUDA_PATH` set). - **Python**: ≥ 3.10. - **OS**: Linux x86_64 (aarch64 untested; JIT builds may need small Makefile edits on WSL). @@ -33,8 +35,8 @@ share one Python package: Quick sanity check before installing: ```bash -nvcc --version # expect ≥ 12.x -nvidia-smi --query-gpu=compute_cap --format=csv | grep "10.0" # confirm SM100 +nvcc --version # expect ≥ 12.x for SM100, CUDA 13.x for SM120/SM121 +nvidia-smi --query-gpu=compute_cap --format=csv | grep -E "10\.(0|3)|12\.(0|1)" python -c "import sys; print(sys.version_info[:2])" # ≥ (3, 10) ``` @@ -56,6 +58,26 @@ This pulls in the CuTe-DSL stack via `nvidia-cutlass-dsl` and `quack-kernels`; the csrc kernels are JIT-compiled at first import from sources shipped inside the package. +By default, csrc JIT builds preserve the original SM100/SM103 targets. The +`fmha_sm12x` package is a parallel SM120/SM121 namespace: it ships real +portable CUDA helpers (the k2q CSR builder, paged-decode split-KV scheduler, +and top-k selector), a semi-optimized Triton kernel for block-sparse prefill +attention (BF16/FP16, plus FP8 E4M3 and NVFP4 K/V staged to BF16), and +correctness-first Torch references for the rest (dense attention, FP4 block +scoring, paged decode). The +Triton path is optional — Triton ships transitively with `torch` on Linux, is +imported lazily, and falls back to the Torch reference when unavailable, so it +is not a declared dependency. Do not compile the existing `fmha_sm100` kernels +for GB10 / SM121; they rely on SM100-only tcgen05/TMEM operations. For SM12x +development, set the target before importing compiled SM12x kernel modules so +caches are partitioned by architecture: + +```bash +export MSA_CUDA_ARCH=sm_121 +# For custom multi-target builds, override the full nvcc gencode list: +# export MSA_NVCC_GENCODES='-gencode=arch=compute_120,code=sm_120 -gencode=arch=compute_121,code=sm_121' +``` + ## Verify Run a small CUDA smoke test. **The first run JIT-compiles `sparse_topk_select`, @@ -160,11 +182,12 @@ python/fmha_sm100/ Python package api.py fmha_sm100 / fmha_sm100_plan / sparse_topk_select jit.py Runtime JIT (nvcc + ninja) for the csrc stack sparse.py Lazy shim that loads the cute/ stack - sparse_fmha_adapter.py Bridge: fmha_sm100 API → sparse_atten_func + sparse_fmha_adapter.py Bridge: fmha_sm100 API -> sparse_atten_func csrc/ CUDA kernels + Jinja templates (JIT-compiled) include/ Vendored FlashInfer / CUTLASS-derived / TRT-LLM headers cutlass/ NVIDIA CUTLASS git submodule (include/ + tools/util/include/) cute/ CuTe-DSL sparse attention (loaded via sys.path) +fmha_sm12x/ SM120/SM121 attention/indexer/decode namespace (Triton + Torch) tests/ Correctness tests smoke/ integration/ regression/ scripts/ Warmup + cache-management helpers @@ -184,6 +207,21 @@ benchmarks/ bench_sparse_attention_ops.py site to the sparse backend for prefill paths; useful when you already drive the dense kernel and want a one-line swap to sparse. +### SM12x status + +The SM100 tcgen05/TMEM kernels stay isolated; `fmha_sm12x` exposes independent +SM120/SM121-safe routes that mirror the `fmha_sm100` public surface. The +portable helpers are real optimized CUDA kernels: the q2k→k2q CSR builder, the +paged-decode split-KV scheduler, and the `sparse_topk_select` indexer. Block- +sparse prefill attention — including the FP8 E4M3 and NVFP4 K/V variants, which +are staged/dequantized to BF16 — runs on a semi-optimized Triton kernel for +dense KV, used by default when Triton is importable and falling back to the +Torch reference otherwise. Dense FMHA, the FP4 indexer block scores, and paged +decode (BF16, or FP8 staged to BF16) remain correctness-first Torch +references. These routes cover the full API and are validated on GB10; +matching SM100's fused-kernel throughput would still need dedicated SM12x +CUTLASS/CuTe kernels. + ## Third-party licenses `fmha_sm100` bundles, derives from, or depends on the third-party components diff --git a/pyproject.toml b/pyproject.toml index a4a4606..1d58ff1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,10 +34,15 @@ dependencies = [ # which adds that directory to sys.path and re-exports the public sparse API. [tool.setuptools] package-dir = { "" = "python" } -# The cute/ sparse sources are loaded via sys.path.insert at runtime by -# fmha_sm100/sparse.py (not imported as a submodule), so they are shipped as -# package data. -packages = ["fmha_sm100"] + +# Auto-discover the packages. fmha_sm12x's cute/src loaders are real importable +# subpackages, so the glob pulls them all in without listing each one. +# fmha_sm100 stays a single package on purpose: its cute/ + cutlass/ trees are +# loaded via sys.path at runtime and ship as package data (below), so they must +# NOT be discovered as importable packages. +[tool.setuptools.packages.find] +where = ["python"] +include = ["fmha_sm100", "fmha_sm12x*", "minimax_msa"] [tool.setuptools.package-data] fmha_sm100 = [ @@ -54,3 +59,11 @@ fmha_sm100 = [ "cutlass/include/**/*", "cutlass/tools/util/include/**/*", ] + +fmha_sm12x = [ + "cute/**/*.py", + "cute/**/*.cu", + "cute/**/*.cuh", + "cute/**/*.h", + "cute/**/*.hpp", +] diff --git a/python/fmha_sm12x/__init__.py b/python/fmha_sm12x/__init__.py new file mode 100644 index 0000000..9effaf1 --- /dev/null +++ b/python/fmha_sm12x/__init__.py @@ -0,0 +1,82 @@ +# SPDX-License-Identifier: MIT + +"""SM120/SM121 package facade for MiniMax Sparse Attention.""" + +from __future__ import annotations + +from minimax_msa.arch import ( + CudaArch, + cpp_extension_arch_flag, + cuda_arch_cache_suffix, + nvcc_gencode_flags, + selected_cuda_arch, +) + +_API_EXPORTS = frozenset( + { + "Sm12xPlan", + "fmha_sm12x", + "fmha_sm12x_plan", + "sparse_topk_select", + } +) +_SPARSE_EXPORTS = frozenset( + { + "Nvfp4QuantizedTensor", + "SparseDecodePagedAttentionWrapper", + "SparseK2qCsrBuilderSm12x", + "build_k2q_csr", + "dequantize_nvfp4_128x4_to_bf16", + "fp4_indexer_block_scores", + "nvfp4_global_scale_from_amax", + "nvfp4_scale_128x4_offset", + "quantize_bf16_to_nvfp4_128x4", + "quantize_kv_bf16_to_nvfp4_128x4", + "sparse_atten_func", + "sparse_atten_nvfp4_kv_func", + "sparse_decode_atten_func", + "swizzle_nvfp4_scale_to_128x4", + } +) + +__all__ = [ + "CudaArch", + "Nvfp4QuantizedTensor", + "Sm12xPlan", + "SparseDecodePagedAttentionWrapper", + "SparseK2qCsrBuilderSm12x", + "build_k2q_csr", + "cpp_extension_arch_flag", + "cuda_arch_cache_suffix", + "dequantize_nvfp4_128x4_to_bf16", + "fmha_sm12x", + "fmha_sm12x_plan", + "fp4_indexer_block_scores", + "nvcc_gencode_flags", + "nvfp4_global_scale_from_amax", + "nvfp4_scale_128x4_offset", + "quantize_bf16_to_nvfp4_128x4", + "quantize_kv_bf16_to_nvfp4_128x4", + "selected_cuda_arch", + "sparse_atten_func", + "sparse_atten_nvfp4_kv_func", + "sparse_decode_atten_func", + "sparse_topk_select", + "swizzle_nvfp4_scale_to_128x4", +] + + +def __getattr__(name: str): + if name in _API_EXPORTS: + from . import api as _api + + return getattr(_api, name) + if name in _SPARSE_EXPORTS: + from . import sparse as _sparse + + return getattr(_sparse, name) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +def __dir__() -> list[str]: + return sorted({*globals(), *__all__}) diff --git a/python/fmha_sm12x/_decode.py b/python/fmha_sm12x/_decode.py new file mode 100644 index 0000000..61e50d2 --- /dev/null +++ b/python/fmha_sm12x/_decode.py @@ -0,0 +1,131 @@ +# SPDX-License-Identifier: MIT + +"""SM12x paged decode reference implementation.""" + +from __future__ import annotations + +import torch + +from ._lse import run_lse +from .api import fmha_sm12x, fmha_sm12x_plan + + +def _compact_page_table(page_table: torch.Tensor, seqused_k: torch.Tensor, page_size: int) -> torch.Tensor: + pages: list[torch.Tensor] = [] + lengths = seqused_k.to("cpu", dtype=torch.int64, non_blocking=False).tolist() + for batch, kv_len in enumerate(lengths): + page_count = (int(kv_len) + int(page_size) - 1) // int(page_size) + if page_count > 0: + pages.append(page_table[batch, :page_count]) + if not pages: + return torch.empty((0,), dtype=torch.int32, device=page_table.device) + return torch.cat(pages).contiguous() + + +def sparse_decode_atten_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + q2k_indices: torch.Tensor | None = None, + *, + page_table: torch.Tensor, + seqused_k: torch.Tensor, + seqlen_q: int, + max_seqlen_k: int, + blk_kv: int = 128, + causal: bool = True, + softmax_scale: float | None = None, + return_softmax_lse: bool = False, + **_kwargs, +): + """Run SM12x paged decode via the dequantized Torch reference path.""" + + if q.ndim != 3 or k.ndim != 4 or v.ndim != 4: + raise ValueError("decode expects q [B*S,Hq,D] and paged k/v [P,Hkv,page,D]") + # FP8 E4M3 K/V cache (and FP8 Q) is staged to BF16, matching the SM100 + # decode path; the reference then runs in BF16. + q = q.to(torch.bfloat16) if q.dtype == torch.float8_e4m3fn else q + k = k.to(torch.bfloat16) if k.dtype == torch.float8_e4m3fn else k + v = v.to(torch.bfloat16) if v.dtype == torch.float8_e4m3fn else v + if page_table.dtype != torch.int32 or seqused_k.dtype != torch.int32: + raise TypeError("page_table and seqused_k must be int32") + if page_table.device != q.device or seqused_k.device != q.device: + raise ValueError("decode metadata must be on q.device") + batch = int(page_table.shape[0]) + if int(q.shape[0]) != batch * int(seqlen_q): + raise ValueError("q.shape[0] must equal batch * seqlen_q") + if int(k.shape[2]) != int(blk_kv) or k.shape != v.shape: + raise ValueError("paged k/v shapes must match and use blk_kv page size") + if max_seqlen_k <= 0: + raise ValueError("max_seqlen_k must be positive") + qo_lens = torch.full((batch,), int(seqlen_q), dtype=torch.int32, device=q.device) + kv_lens = seqused_k.to(dtype=torch.int32) + kv_indices = _compact_page_table(page_table.contiguous(), seqused_k.contiguous(), int(blk_kv)) + plan = fmha_sm12x_plan( + qo_lens, kv_lens, int(q.shape[1]), int(k.shape[1]), page_size=int(blk_kv), + causal=bool(causal), + ) + block_indexes = None + if q2k_indices is not None: + if q2k_indices.dtype != torch.int32 or q2k_indices.ndim != 3: + raise ValueError("q2k_indices must be int32 with shape [Hkv, total_q, topK]") + block_indexes = q2k_indices.permute(1, 0, 2).contiguous() + out, _ = fmha_sm12x( + q, k, v, plan, kv_indices=kv_indices, kv_block_indexes=block_indexes, + sm_scale=softmax_scale, + ) + if return_softmax_lse: + lse = run_lse( + q, k, v, plan, kv_indices=kv_indices, kv_block_indexes=block_indexes, + sm_scale=softmax_scale, + ) + return out, lse + return out + + +class SparseDecodePagedAttentionWrapper: + """Plan/run wrapper matching the SM100 decode surface for SM12x.""" + + def __init__(self, *, blk_kv: int = 128, causal: bool = True) -> None: + self.blk_kv = int(blk_kv) + self.causal = bool(causal) + self.page_table: torch.Tensor | None = None + self.seqused_k: torch.Tensor | None = None + self.q2k_indices: torch.Tensor | None = None + self.seqlen_q: int | None = None + self.max_seqlen_k: int | None = None + + def plan( + self, + *, + page_table: torch.Tensor, + seqused_k: torch.Tensor, + seqlen_q: int, + max_seqlen_k: int, + q2k_indices: torch.Tensor | None = None, + **_kwargs, + ) -> "SparseDecodePagedAttentionWrapper": + self.page_table = page_table.contiguous() + self.seqused_k = seqused_k.contiguous() + self.q2k_indices = None if q2k_indices is None else q2k_indices.contiguous() + self.seqlen_q = int(seqlen_q) + self.max_seqlen_k = int(max_seqlen_k) + return self + + def run( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + *, + softmax_scale: float | None = None, + return_softmax_lse: bool = False, + **_kwargs, + ): + if self.page_table is None or self.seqused_k is None or self.seqlen_q is None or self.max_seqlen_k is None: + raise RuntimeError("decode wrapper must be planned before run") + return sparse_decode_atten_func( + q, k, v, self.q2k_indices, page_table=self.page_table, seqused_k=self.seqused_k, + seqlen_q=self.seqlen_q, max_seqlen_k=self.max_seqlen_k, blk_kv=self.blk_kv, + causal=self.causal, softmax_scale=softmax_scale, return_softmax_lse=return_softmax_lse, + ) diff --git a/python/fmha_sm12x/_fp4.py b/python/fmha_sm12x/_fp4.py new file mode 100644 index 0000000..02bf217 --- /dev/null +++ b/python/fmha_sm12x/_fp4.py @@ -0,0 +1,189 @@ +# SPDX-License-Identifier: MIT + +"""SM12x FP4 indexer reference implementation.""" + +from __future__ import annotations + +import torch + +_PAGE_SIZE = 128 +_PACKED_D_BYTES = 64 +_HEAD_DIM = 128 +_PUBLIC_SCALE_LAYOUT = "public" +_PREORDERED_MMA_SCALE_LAYOUT = "preordered_mma" +_FP4_VALUES = ( + 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, + -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, +) + + +def _ceil_div(x: int, y: int) -> int: + return (int(x) + int(y) - 1) // int(y) + + +def _scale_groups(fp4_format: str) -> int: + match str(fp4_format).lower(): + case "mxfp4": + return 4 + case "nvfp4": + return 8 + case other: + raise ValueError(f"fp4_format must be 'mxfp4' or 'nvfp4', got {other!r}") + + +def _require_i32_vector(tensor: torch.Tensor, *, name: str, device: torch.device) -> None: + if tensor.device != device or tensor.dtype != torch.int32 or tensor.ndim != 1: + raise ValueError(f"{name} must be rank-1 int32 on {device}") + if not tensor.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def _as_u8(tensor: torch.Tensor, *, name: str, expected_ndim: int) -> torch.Tensor: + if tensor.ndim != expected_ndim: + raise ValueError(f"{name} must be rank {expected_ndim}") + if int(tensor.shape[-1]) != _PACKED_D_BYTES: + raise ValueError(f"{name}.shape[-1] must be 64 packed bytes") + if tensor.dtype == torch.uint8: + return tensor.contiguous() + return tensor.contiguous().view(torch.uint8) + + +def _unpack_fp4(packed: torch.Tensor) -> torch.Tensor: + lut = torch.tensor(_FP4_VALUES, dtype=torch.float32, device=packed.device) + u8 = packed.to(torch.uint8) + lo = u8 & 0x0F + hi = u8 >> 4 + out = torch.empty((*u8.shape[:-1], _HEAD_DIM), dtype=torch.float32, device=u8.device) + out[..., 0::2] = lut[lo.long()] + out[..., 1::2] = lut[hi.long()] + return out + + +def _restore_preordered_q_scale(scale: torch.Tensor, total_q: int, heads: int, groups: int) -> torch.Tensor: + public = torch.empty((total_q, heads, groups), dtype=scale.dtype, device=scale.device) + if scale.ndim == 6 and scale.shape[0] == 32: + for row in range(total_q): + r0 = row % 32 + r1 = (row // 32) % 4 + r2 = row // 128 + for group in range(groups): + public[row, :, group] = scale[r0, r1, r2, group % 4, group // 4, :heads] + return public + if scale.ndim == 6: + for row in range(total_q): + for group in range(groups): + public[row, :, group] = scale[:heads, row // 128, group // 4, row % 32, (row // 32) % 4, group % 4] + return public + raise ValueError("preordered q_scale must be a rank-6 MMA scale tensor") + + +def _restore_preordered_k_scale(scale: torch.Tensor, pages: int, heads: int, groups: int) -> torch.Tensor: + public = torch.empty((pages, heads, _PAGE_SIZE, groups), dtype=scale.dtype, device=scale.device) + if scale.ndim == 6 and scale.shape[0] == 32: + for page in range(pages): + for head in range(heads): + scale_l = page * heads + head + for row in range(_PAGE_SIZE): + for group in range(groups): + public[page, head, row, group] = scale[row % 32, (row // 32) % 4, 0, group % 4, group // 4, scale_l] + return public + if scale.ndim == 6: + for page in range(pages): + for head in range(heads): + scale_l = page * heads + head + for row in range(_PAGE_SIZE): + for group in range(groups): + public[page, head, row, group] = scale[scale_l, 0, group // 4, row % 32, (row // 32) % 4, group % 4] + return public + raise ValueError("preordered k_scale must be a rank-6 MMA scale tensor") + + +def _public_scales(scale: torch.Tensor, *, shape: tuple[int, ...], layout: str, fp4_format: str) -> torch.Tensor: + groups = _scale_groups(fp4_format) + if layout == _PUBLIC_SCALE_LAYOUT: + if tuple(scale.shape) != shape: + raise ValueError(f"scale must have shape {shape}, got {tuple(scale.shape)}") + return scale.contiguous() + if layout != _PREORDERED_MMA_SCALE_LAYOUT: + raise ValueError(f"scale_layout must be 'public' or 'preordered_mma', got {layout!r}") + if len(shape) == 3: + return _restore_preordered_q_scale(scale, shape[0], shape[1], groups) + return _restore_preordered_k_scale(scale, shape[0], shape[1], groups) + + +def _dequantize_fp4(packed: torch.Tensor, scale: torch.Tensor, *, fp4_format: str) -> torch.Tensor: + groups = _scale_groups(fp4_format) + logical = _unpack_fp4(packed) + scale_f = scale.to(torch.float32).repeat_interleave(_HEAD_DIM // groups, dim=-1) + return logical * scale_f + + +def fp4_indexer_block_scores( + q_fp4: torch.Tensor, + k_fp4: torch.Tensor, + q_scale: torch.Tensor, + k_scale: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + cu_page_offsets: torch.Tensor, + *, + max_seqlen_q: int, + max_seqlen_k: int, + kv_indices: torch.Tensor, + fp4_format: str, + causal: bool = False, + qo_offset: torch.Tensor | None = None, + scale_layout: str = _PUBLIC_SCALE_LAYOUT, +) -> torch.Tensor: + """Return dequantized FP4 QK max scores per 128-token KV page.""" + + q_bytes = _as_u8(q_fp4, name="q_fp4", expected_ndim=3) + k_bytes = _as_u8(k_fp4, name="k_fp4", expected_ndim=4) + total_q, heads_q, _ = (int(v) for v in q_bytes.shape) + pages, heads_k, page_size, _ = (int(v) for v in k_bytes.shape) + if page_size != _PAGE_SIZE: + raise ValueError(f"k_fp4 page size must be {_PAGE_SIZE}, got {page_size}") + if heads_q % heads_k != 0: + raise ValueError("q heads must be divisible by KV heads") + for name, tensor in (("cu_seqlens_q", cu_seqlens_q), ("cu_seqlens_k", cu_seqlens_k), ("cu_page_offsets", cu_page_offsets)): + _require_i32_vector(tensor, name=name, device=q_fp4.device) + if kv_indices.device != q_fp4.device or kv_indices.dtype != torch.int32 or kv_indices.ndim != 1: + raise ValueError("kv_indices must be rank-1 int32 on q_fp4.device") + batch = int(cu_seqlens_q.numel() - 1) + if qo_offset is not None: + _require_i32_vector(qo_offset, name="qo_offset", device=q_fp4.device) + if qo_offset.shape != (batch,): + raise ValueError("qo_offset must have shape [batch]") + q_scales = _public_scales(q_scale, shape=(total_q, heads_q, _scale_groups(fp4_format)), layout=scale_layout, fp4_format=fp4_format) + k_scales = _public_scales(k_scale, shape=(pages, heads_k, _PAGE_SIZE, _scale_groups(fp4_format)), layout=scale_layout, fp4_format=fp4_format) + q = _dequantize_fp4(q_bytes, q_scales, fp4_format=fp4_format) + k = _dequantize_fp4(k_bytes, k_scales, fp4_format=fp4_format) + max_tiles = _ceil_div(int(max_seqlen_k), _PAGE_SIZE) + scores = torch.full((heads_q, max_tiles, total_q), float("-inf"), dtype=torch.float32, device=q_fp4.device) + q_cpu = cu_seqlens_q.to("cpu", dtype=torch.int64, non_blocking=False) + k_cpu = cu_seqlens_k.to("cpu", dtype=torch.int64, non_blocking=False) + page_cpu = cu_page_offsets.to("cpu", dtype=torch.int64, non_blocking=False) + offset_cpu = None if qo_offset is None else qo_offset.to("cpu", dtype=torch.int64, non_blocking=False) + h_ratio = heads_q // heads_k + for b in range(batch): + q_begin = int(q_cpu[b]) + q_len = int(q_cpu[b + 1] - q_cpu[b]) + kv_len = int(k_cpu[b + 1] - k_cpu[b]) + page_begin = int(page_cpu[b]) + for local_q in range(q_len): + q_idx = q_begin + local_q + visible_limit = (int(offset_cpu[b]) if offset_cpu is not None else kv_len - q_len) + local_q + for tile in range(min(max_tiles, _ceil_div(kv_len, _PAGE_SIZE))): + phys_page = int(kv_indices[page_begin + tile].item()) + valid = min(_PAGE_SIZE, kv_len - tile * _PAGE_SIZE) + if valid <= 0: + continue + positions = torch.arange(valid, device=q_fp4.device, dtype=torch.long) + tile * _PAGE_SIZE + visible = positions <= visible_limit if causal else torch.ones_like(positions, dtype=torch.bool) + if not bool(visible.any().item()): + continue + for head in range(heads_q): + kv_head = head // h_ratio + logits = torch.matmul(k[phys_page, kv_head, :valid].float(), q[q_idx, head].float()) + scores[head, tile, q_idx] = logits[visible].max() + return scores diff --git a/python/fmha_sm12x/_lse.py b/python/fmha_sm12x/_lse.py new file mode 100644 index 0000000..849b343 --- /dev/null +++ b/python/fmha_sm12x/_lse.py @@ -0,0 +1,73 @@ +# SPDX-License-Identifier: MIT + +from __future__ import annotations + +import math + +import torch + +from ._reference import Sm12xPlan, _gather_kv, _page_offsets, _selected_positions, _validate_runtime_inputs + + +def _logsumexp_one( + query: torch.Tensor, + keys: torch.Tensor, + positions: torch.Tensor, + *, + visible_limit: int, + causal: bool, + sm_scale: float, +) -> torch.Tensor: + if positions.numel() == 0: + return torch.tensor(float("-inf"), dtype=torch.float32, device=query.device) + visible = positions <= int(visible_limit) if causal else torch.ones_like(positions, dtype=torch.bool) + if not bool(visible.any().item()): + return torch.tensor(float("-inf"), dtype=torch.float32, device=query.device) + logits = torch.matmul(keys.index_select(0, positions).float(), query.float()) * float(sm_scale) + return torch.logsumexp(logits.masked_fill(~visible, float("-inf")), dim=0) + + +def run_lse( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + plan: Sm12xPlan, + *, + kv_indices: torch.Tensor | None, + kv_block_indexes: torch.Tensor | None, + sm_scale: float | None, +) -> torch.Tensor: + _validate_runtime_inputs(q, k, v, plan) + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(int(q.shape[-1])) + page_size = plan.page_size if plan.page_size > 0 else 128 + h_ratio = plan.num_qo_heads // plan.num_kv_heads + lse = torch.empty(q.shape[:2], dtype=torch.float32, device=q.device) + q_start = 0 + kv_start = 0 + page_offsets = _page_offsets(plan.kv_segment_lens, page_size) + for batch, (qo_len, kv_len, offset) in enumerate( + zip(plan.qo_segment_lens.tolist(), plan.kv_segment_lens.tolist(), plan.qo_offset.tolist(), strict=True) + ): + keys_b, _ = _gather_kv( + k, v, kv_start=kv_start, kv_len=int(kv_len), page_begin=page_offsets[batch], + page_size=page_size, kv_indices=kv_indices, + ) + full_positions = torch.arange(int(kv_len), device=q.device, dtype=torch.long) + for local_q in range(int(qo_len)): + q_index = q_start + local_q + visible_limit = int(offset) + local_q + for head in range(plan.num_qo_heads): + kv_head = head // h_ratio + positions = full_positions + if kv_block_indexes is not None: + block_head = kv_head if kv_block_indexes.shape[1] == plan.num_kv_heads else head + positions = _selected_positions(kv_block_indexes[q_index, block_head], int(kv_len), page_size, q.device) + lse[q_index, head] = _logsumexp_one( + q[q_index, head], keys_b[:, kv_head], positions, + visible_limit=visible_limit, causal=plan.causal, sm_scale=float(sm_scale), + ) + q_start += int(qo_len) + if k.ndim == 3: + kv_start += int(kv_len) + return lse diff --git a/python/fmha_sm12x/_nvfp4.py b/python/fmha_sm12x/_nvfp4.py new file mode 100644 index 0000000..976816b --- /dev/null +++ b/python/fmha_sm12x/_nvfp4.py @@ -0,0 +1,247 @@ +# SPDX-License-Identifier: MIT + +"""SM12x NVFP4 sparse-attention reference implementation.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Final + +import torch + + +NVFP4_BLOCK_SIZE: Final = 16 +NVFP4_FP4_MAX: Final = 6.0 +NVFP4_FP8_E4M3_MAX: Final = 448.0 +_HEAD_DIM: Final = 128 +_FP4_VALUES: Final = ( + 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, + -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, +) + + +@dataclass(frozen=True, slots=True) +class Nvfp4QuantizedTensor: + """Packed NVFP4 tensor plus the metadata needed to dequantize it.""" + + data: torch.Tensor + scale_128x4: torch.Tensor + global_scale: torch.Tensor + logical_scale_shape: tuple[int, int] + original_shape: tuple[int, ...] + + +def _round_up(x: int, multiple: int) -> int: + return ((int(x) + int(multiple) - 1) // int(multiple)) * int(multiple) + + +def nvfp4_scale_128x4_offset(row: torch.Tensor, col: torch.Tensor, scale_cols: int) -> torch.Tensor: + """Return flat offsets for cuBLAS/cuDNN 128x4 rowwise scale storage.""" + + padded_cols = _round_up(scale_cols, 4) + tiles_n = padded_cols // 4 + tile_m = row // 128 + tile_n = col // 4 + outer = row % 128 + inner = col % 4 + return (tile_m * tiles_n + tile_n) * 512 + (outer % 32) * 16 + (outer // 32) * 4 + inner + + +def swizzle_nvfp4_scale_to_128x4(scale: torch.Tensor, *, rows: int, cols: int) -> torch.Tensor: + """Convert TE logical rowwise scales to cuBLAS/cuDNN 128x4 tiled layout.""" + + if scale.ndim != 2: + raise ValueError(f"scale must be 2D, got shape {tuple(scale.shape)}") + rows = int(rows) + cols = int(cols) + padded_rows = _round_up(rows, 128) + padded_cols = _round_up(cols, 4) + if scale.shape[0] < rows or scale.shape[1] < cols: + raise ValueError( + "scale is smaller than the requested logical shape: " + f"got {tuple(scale.shape)}, need at least {(rows, cols)}" + ) + logical = scale[:rows, :cols].contiguous() + if logical.shape != (padded_rows, padded_cols): + logical = torch.nn.functional.pad( + logical.to(torch.float32), + (0, padded_cols - cols, 0, padded_rows - rows), + ).to(scale.dtype) + swizzled = torch.empty_like(logical) + row = torch.arange(padded_rows, device=scale.device, dtype=torch.int64)[:, None] + col = torch.arange(padded_cols, device=scale.device, dtype=torch.int64)[None, :] + offset = nvfp4_scale_128x4_offset(row, col, padded_cols).reshape(-1) + swizzled.reshape(-1)[offset] = logical.reshape(-1) + return swizzled + + +def nvfp4_global_scale_from_amax(amax: torch.Tensor) -> torch.Tensor: + """Compute TE NVFP4 tensor/global dequant scale from rowwise amax.""" + + return amax.to(torch.float32) / (NVFP4_FP8_E4M3_MAX * NVFP4_FP4_MAX) + + +def _import_te_nvfp4_quantizer(): + try: + from transformer_engine.pytorch.tensor import NVFP4Quantizer + except (ImportError, OSError) as exc: # pragma: no cover - environment dependent + raise RuntimeError( + "Transformer Engine NVFP4 quantization is unavailable. Install a " + "Transformer Engine build with its PyTorch dependencies." + ) from exc + return NVFP4Quantizer + + +def quantize_bf16_to_nvfp4_128x4(x: torch.Tensor) -> Nvfp4QuantizedTensor: + """Quantize a BF16/FP16 tensor to NVFP4 using Transformer Engine.""" + + if not x.is_cuda: + raise ValueError("NVFP4 quantization requires a CUDA tensor") + if x.dtype not in (torch.bfloat16, torch.float16): + raise TypeError(f"x must be bf16 or fp16, got {x.dtype}") + if x.ndim < 2: + raise ValueError(f"x must have at least 2 dimensions, got {x.ndim}") + if x.shape[-1] % NVFP4_BLOCK_SIZE != 0: + raise ValueError(f"last dimension must be divisible by {NVFP4_BLOCK_SIZE}, got {x.shape[-1]}") + rows = 1 + for dim in x.shape[:-1]: + rows *= int(dim) + if rows % NVFP4_BLOCK_SIZE != 0: + raise ValueError(f"flattened row dimension must be divisible by {NVFP4_BLOCK_SIZE}, got {rows}") + + quantizer_type = _import_te_nvfp4_quantizer() + quantizer = quantizer_type(rowwise=True, columnwise=False) + qx = quantizer.quantize(x.contiguous()) + meta = qx.get_metadata() + data = meta["rowwise_data"] + if data.dtype != torch.uint8: + data = data.view(torch.uint8) + scale_cols = int(x.shape[-1]) // NVFP4_BLOCK_SIZE + scale_128x4 = swizzle_nvfp4_scale_to_128x4(meta["rowwise_scale_inv"], rows=rows, cols=scale_cols) + return Nvfp4QuantizedTensor( + data=data, + scale_128x4=scale_128x4, + global_scale=nvfp4_global_scale_from_amax(meta["amax_rowwise"]).contiguous(), + logical_scale_shape=(rows, scale_cols), + original_shape=tuple(int(v) for v in x.shape), + ) + + +def quantize_kv_bf16_to_nvfp4_128x4( + k: torch.Tensor, + v: torch.Tensor, +) -> tuple[Nvfp4QuantizedTensor, Nvfp4QuantizedTensor]: + """Quantize BF16/FP16 K and V tensors independently for KVFP4 attention.""" + + return quantize_bf16_to_nvfp4_128x4(k), quantize_bf16_to_nvfp4_128x4(v) + + +def _unpack_nvfp4(data: torch.Tensor, logical_dim: int) -> torch.Tensor: + lut = torch.tensor(_FP4_VALUES, dtype=torch.float32, device=data.device) + packed = data.contiguous().view(torch.uint8) + values = torch.empty((*packed.shape[:-1], logical_dim), dtype=torch.float32, device=data.device) + values[..., 0::2] = lut[(packed & 0x0F).long()] + values[..., 1::2] = lut[(packed >> 4).long()] + return values + + +def dequantize_nvfp4_128x4( + data: torch.Tensor, + scale_128x4: torch.Tensor, + global_scale: torch.Tensor | None, + *, + original_shape: tuple[int, ...], +) -> torch.Tensor: + """Dequantize packed NVFP4 data with cuBLAS/cuDNN 128x4 scales.""" + + if data.dtype != torch.uint8: + data = data.view(torch.uint8) + logical_dim = int(original_shape[-1]) + if logical_dim % NVFP4_BLOCK_SIZE != 0: + raise ValueError(f"last dimension must be divisible by {NVFP4_BLOCK_SIZE}, got {logical_dim}") + if data.shape[-1] * 2 != logical_dim: + raise ValueError("packed data last dimension does not match original shape") + rows = 1 + for dim in original_shape[:-1]: + rows *= int(dim) + scale_cols = logical_dim // NVFP4_BLOCK_SIZE + values = _unpack_nvfp4(data, logical_dim).reshape(rows, logical_dim) + row = torch.arange(rows, device=data.device, dtype=torch.int64)[:, None] + col = torch.arange(scale_cols, device=data.device, dtype=torch.int64)[None, :] + offsets = nvfp4_scale_128x4_offset(row, col, scale_cols) + scale = scale_128x4.reshape(-1)[offsets.reshape(-1)].reshape(rows, scale_cols) + scale_f = scale.view(torch.float8_e4m3fn).to(torch.float32).repeat_interleave(NVFP4_BLOCK_SIZE, dim=1) + out = values * scale_f + if global_scale is not None: + out = out * global_scale.reshape(-1)[0].to(torch.float32) + return out.reshape(original_shape).to(torch.bfloat16) + + +def dequantize_nvfp4_128x4_to_bf16( + qx: Nvfp4QuantizedTensor, + *, + include_global_scale: bool = True, +) -> torch.Tensor: + """Reference dequantization for validation of packed NVFP4 tensors.""" + + global_scale = qx.global_scale if include_global_scale else None + return dequantize_nvfp4_128x4( + qx.data, + qx.scale_128x4, + global_scale, + original_shape=qx.original_shape, + ) + + +def sparse_atten_nvfp4_kv_func( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + k_scale_128x4: torch.Tensor, + v_scale_128x4: torch.Tensor, + k_global_scale: torch.Tensor | None, + v_global_scale: torch.Tensor | None, + k2q_row_ptr: torch.Tensor, + k2q_q_indices: torch.Tensor, + topK: int, + *, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + blk_kv: int = 128, + causal: bool = False, + softmax_scale: float | None = None, + lse_temperature_scale: float = 1.0, + return_temperature_lse: bool = False, + partial_dtype: "torch.dtype" = torch.bfloat16, + return_softmax_lse: bool = False, + page_table: torch.Tensor | None = None, + seqused_k: torch.Tensor | None = None, + schedule: object | None = None, + usable_SM_count: int = -1, + qk_dtype: "torch.dtype | None" = None, + pv_dtype: "torch.dtype | None" = None, + **_kwargs, +): + """Run SM12x sparse attention by dequantizing packed NVFP4 K/V. + + Mirrors the SM100 ``sparse_atten_nvfp4_kv_func`` surface, including the + temperature-scaled LSE outputs, by forwarding to ``sparse_atten_func``. + """ + + from .sparse import sparse_atten_func + + logical_shape = (*k.shape[:-1], _HEAD_DIM) + k_bf16 = dequantize_nvfp4_128x4(k, k_scale_128x4, k_global_scale, original_shape=logical_shape) + v_bf16 = dequantize_nvfp4_128x4(v, v_scale_128x4, v_global_scale, original_shape=logical_shape) + return sparse_atten_func( + q, k_bf16, v_bf16, k2q_row_ptr, k2q_q_indices, int(topK), + cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=int(max_seqlen_q), max_seqlen_k=int(max_seqlen_k), + blk_kv=int(blk_kv), causal=bool(causal), softmax_scale=softmax_scale, + lse_temperature_scale=float(lse_temperature_scale), + return_temperature_lse=bool(return_temperature_lse), partial_dtype=partial_dtype, + return_softmax_lse=bool(return_softmax_lse), page_table=page_table, seqused_k=seqused_k, + schedule=schedule, usable_SM_count=int(usable_SM_count), qk_dtype=qk_dtype, pv_dtype=pv_dtype, + ) diff --git a/python/fmha_sm12x/_reference.py b/python/fmha_sm12x/_reference.py new file mode 100644 index 0000000..c9a2992 --- /dev/null +++ b/python/fmha_sm12x/_reference.py @@ -0,0 +1,247 @@ +# SPDX-License-Identifier: MIT + +"""Reference-correct SM12x attention paths. + +These are intentionally simple Torch implementations. They provide a safe +SM120/SM121 semantic target while the production kernels remain separate from +the SM100 tcgen05/TMEM implementation. +""" + +from __future__ import annotations + +from dataclasses import dataclass +import math + +import torch + + +@dataclass(frozen=True, slots=True) +class Sm12xPlan: + qo_segment_lens: torch.Tensor + kv_segment_lens: torch.Tensor + qo_offset: torch.Tensor + num_qo_heads: int + num_kv_heads: int + page_size: int + output_maxscore: bool + kv_block_num: int + causal: bool + + +def _cpu_i64(tensor: torch.Tensor) -> torch.Tensor: + return tensor.to(device="cpu", dtype=torch.int64, non_blocking=False).contiguous() + + +def make_plan( + qo_segment_lens: torch.Tensor, + kv_segment_lens: torch.Tensor, + *, + num_qo_heads: int, + num_kv_heads: int, + qo_offset: int | torch.Tensor | None, + page_size: int, + output_maxscore: bool, + kv_block_num: int, + causal: bool, +) -> Sm12xPlan: + qo_lens = _cpu_i64(qo_segment_lens) + kv_lens = _cpu_i64(kv_segment_lens) + if qo_lens.ndim != 1 or kv_lens.ndim != 1 or qo_lens.shape != kv_lens.shape: + raise ValueError("qo_segment_lens and kv_segment_lens must be same-shape rank-1 tensors") + if num_kv_heads == -1: + num_kv_heads = int(num_qo_heads) + if int(num_qo_heads) % int(num_kv_heads) != 0: + raise ValueError("num_qo_heads must be divisible by num_kv_heads") + if qo_offset is None: + offset = kv_lens - qo_lens + elif isinstance(qo_offset, int): + offset = torch.full_like(qo_lens, int(qo_offset)) + else: + offset = _cpu_i64(qo_offset) + if offset.shape != qo_lens.shape: + raise ValueError("qo_offset must have shape [batch_size]") + if bool((qo_lens < 0).any().item()) or bool((kv_lens < 0).any().item()): + raise ValueError("qo_segment_lens and kv_segment_lens must be non-negative") + return Sm12xPlan( + qo_segment_lens=qo_lens, + kv_segment_lens=kv_lens, + qo_offset=offset, + num_qo_heads=int(num_qo_heads), + num_kv_heads=int(num_kv_heads), + page_size=int(page_size), + output_maxscore=bool(output_maxscore), + kv_block_num=int(kv_block_num), + causal=bool(causal), + ) + + +def _page_offsets(kv_lens: torch.Tensor, page_size: int) -> list[int]: + offsets = [0] + total = 0 + for kv_len in kv_lens.tolist(): + total += (int(kv_len) + page_size - 1) // page_size + offsets.append(total) + return offsets + + +def _gather_kv( + k: torch.Tensor, + v: torch.Tensor, + *, + kv_start: int, + kv_len: int, + page_begin: int, + page_size: int, + kv_indices: torch.Tensor | None, +) -> tuple[torch.Tensor, torch.Tensor]: + if k.ndim == 3: + return k[kv_start : kv_start + kv_len], v[kv_start : kv_start + kv_len] + if k.ndim != 4 or page_size <= 0: + raise ValueError("paged KV requires k/v shape [pages, heads, page_size, dim] and page_size > 0") + page_count = (kv_len + page_size - 1) // page_size + if kv_indices is None: + physical = torch.arange(page_begin, page_begin + page_count, device=k.device) + else: + physical = kv_indices[page_begin : page_begin + page_count].to(torch.long) + k_dense = k.index_select(0, physical).permute(0, 2, 1, 3).reshape(-1, k.shape[1], k.shape[3]) + v_dense = v.index_select(0, physical).permute(0, 2, 1, 3).reshape(-1, v.shape[1], v.shape[3]) + return k_dense[:kv_len], v_dense[:kv_len] + + +def _selected_positions(blocks: torch.Tensor, kv_len: int, page_size: int, device: torch.device) -> torch.Tensor: + valid = blocks[(blocks >= 0) & (blocks * page_size < kv_len)].to(torch.long) + if valid.numel() == 0: + return torch.empty((0,), dtype=torch.long, device=device) + starts = valid * page_size + rel = torch.arange(page_size, device=device, dtype=torch.long) + pos = (starts[:, None] + rel[None, :]).reshape(-1) + return pos[pos < kv_len] + + +def _attend_one( + query: torch.Tensor, + keys: torch.Tensor, + values: torch.Tensor, + positions: torch.Tensor, + *, + visible_limit: int, + causal: bool, + sm_scale: float, +) -> torch.Tensor: + if positions.numel() == 0: + return torch.zeros((values.shape[-1],), dtype=torch.float32, device=query.device) + visible = positions <= int(visible_limit) if causal else torch.ones_like(positions, dtype=torch.bool) + if not bool(visible.any().item()): + return torch.zeros((values.shape[-1],), dtype=torch.float32, device=query.device) + k_sel = keys.index_select(0, positions).float() + v_sel = values.index_select(0, positions).float() + logits = torch.matmul(k_sel, query.float()) * float(sm_scale) + logits = logits.masked_fill(~visible, float("-inf")) + return torch.matmul(torch.softmax(logits, dim=0), v_sel) + + +def _write_tile_scores( + max_score: torch.Tensor, + *, + head: int, + q_index: int, + query: torch.Tensor, + keys: torch.Tensor, + positions: torch.Tensor, + page_size: int, + visible_limit: int, + causal: bool, + sm_scale: float, +) -> None: + if positions.numel() == 0: + return + logits = torch.matmul(keys.index_select(0, positions).float(), query.float()) * float(sm_scale) + visible = positions <= int(visible_limit) if causal else torch.ones_like(positions, dtype=torch.bool) + tile_ids = torch.div(positions, page_size, rounding_mode="floor") + for tile in torch.unique(tile_ids).tolist(): + mask = (tile_ids == int(tile)) & visible + if bool(mask.any().item()): + max_score[head, int(tile), q_index] = logits[mask].max() + + +def _validate_runtime_inputs(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, plan: Sm12xPlan) -> None: + if q.ndim != 3: + raise ValueError("q must have shape [total_q, num_qo_heads, head_dim]") + if q.shape[1] != plan.num_qo_heads: + raise ValueError("q head count does not match plan") + if k.shape != v.shape: + raise ValueError("k and v shapes must match") + if k.ndim not in (3, 4): + raise ValueError("k/v must be dense [total_k, heads, dim] or paged [pages, heads, page, dim]") + total_q = int(plan.qo_segment_lens.sum().item()) + if int(q.shape[0]) != total_q: + raise ValueError("q.shape[0] must equal sum(qo_segment_lens)") + if k.ndim == 3 and int(k.shape[0]) < int(plan.kv_segment_lens.sum().item()): + raise ValueError("k/v sequence length must cover sum(kv_segment_lens)") + + +def run_plan( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + plan: Sm12xPlan, + *, + kv_indices: torch.Tensor | None, + kv_block_indexes: torch.Tensor | None, + out: torch.Tensor | None, + max_score: torch.Tensor | None, + sm_scale: float | None, + output_o: bool, + output_maxscore: bool, +) -> tuple[torch.Tensor | None, torch.Tensor | None]: + _validate_runtime_inputs(q, k, v, plan) + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(int(q.shape[-1])) + page_size = plan.page_size if plan.page_size > 0 else 128 + h_ratio = plan.num_qo_heads // plan.num_kv_heads + result = out if out is not None else torch.empty( + (q.shape[0], plan.num_qo_heads, v.shape[-1]), dtype=torch.bfloat16, device=q.device + ) + want_o = bool(output_o) + want_score = bool(output_maxscore or plan.output_maxscore or max_score is not None) + score = max_score + if want_score and score is None: + max_tiles = max((int(x) + page_size - 1) // page_size for x in plan.kv_segment_lens.tolist()) + score = torch.full((plan.num_qo_heads, max_tiles, q.shape[0]), float("-inf"), dtype=torch.float32, device=q.device) + elif score is not None: + score.fill_(float("-inf")) + q_start = 0 + kv_start = 0 + page_offsets = _page_offsets(plan.kv_segment_lens, page_size) + for batch, (qo_len, kv_len, offset) in enumerate( + zip(plan.qo_segment_lens.tolist(), plan.kv_segment_lens.tolist(), plan.qo_offset.tolist(), strict=True) + ): + keys_b, values_b = _gather_kv( + k, v, kv_start=kv_start, kv_len=int(kv_len), page_begin=page_offsets[batch], + page_size=page_size, kv_indices=kv_indices, + ) + full_positions = torch.arange(int(kv_len), device=q.device, dtype=torch.long) + for local_q in range(int(qo_len)): + q_index = q_start + local_q + visible_limit = int(offset) + local_q + for head in range(plan.num_qo_heads): + kv_head = head // h_ratio + positions = full_positions + if kv_block_indexes is not None: + block_head = kv_head if kv_block_indexes.shape[1] == plan.num_kv_heads else head + positions = _selected_positions(kv_block_indexes[q_index, block_head], int(kv_len), page_size, q.device) + if want_o: + result[q_index, head] = _attend_one( + q[q_index, head], keys_b[:, kv_head], values_b[:, kv_head], positions, + visible_limit=visible_limit, causal=plan.causal, sm_scale=float(sm_scale), + ).to(result.dtype) + if score is not None: + _write_tile_scores( + score, head=head, q_index=q_index, query=q[q_index, head], + keys=keys_b[:, kv_head], positions=positions, page_size=page_size, + visible_limit=visible_limit, causal=plan.causal, sm_scale=float(sm_scale), + ) + q_start += int(qo_len) + if k.ndim == 3: + kv_start += int(kv_len) + return (result if want_o else None, score) diff --git a/python/fmha_sm12x/_topk.py b/python/fmha_sm12x/_topk.py new file mode 100644 index 0000000..d32dbde --- /dev/null +++ b/python/fmha_sm12x/_topk.py @@ -0,0 +1,158 @@ +# SPDX-License-Identifier: MIT + +"""SM12x JIT loader for the ``sparse_topk_select`` csrc kernel. + +The kernel source is arch-agnostic (no tcgen05/TMEM) and ships inside the +``fmha_sm100`` package. SM12x reuses that source read-only but compiles it for +the SM120/SM121 target via :mod:`minimax_msa.arch`, with its own cache dir, so +``fmha_sm100`` itself needs no SM12x-specific code and its SM100 build is +unaffected. +""" + +from __future__ import annotations + +import importlib.util +import os +import shutil +import subprocess +import threading +from pathlib import Path + +from minimax_msa.arch import ( + cuda_arch_cache_suffix, + nvcc_gencode_flags, + require_sm12x_csrc_arch, +) + + +def _fmha_sm100_dir() -> Path: + spec = importlib.util.find_spec("fmha_sm100") + if spec is None or spec.origin is None: + raise RuntimeError("fmha_sm100 package not found; cannot locate sparse_topk_select.cu") + return Path(spec.origin).resolve().parent + + +def _cache_base() -> Path: + explicit = os.environ.get("MINFER_FMHA_CACHE_DIR") + base = Path(explicit) if explicit else Path(os.path.expanduser("~/.cache/minfer/fmha_sm12x")) + suffix = cuda_arch_cache_suffix() + return base.parent / (base.name + suffix) if suffix else base + + +def _tvm_ffi_include() -> str: + import tvm_ffi + + tvm_dir = Path(tvm_ffi.__path__[0]) + for inc in (tvm_dir / "include", tvm_dir.parent / "include"): + if inc.exists(): + return str(inc) + raise RuntimeError("Cannot find TVM-FFI include directory; install apache-tvm-ffi") + + +def _cuda_home() -> str: + if "CUDA_HOME" in os.environ: + return os.environ["CUDA_HOME"] + nvcc = shutil.which("nvcc") + if nvcc: + return str(Path(nvcc).resolve().parent.parent) + for p in ("/usr/local/cuda", "/opt/cuda"): + if os.path.isdir(p): + return p + raise RuntimeError("Cannot find CUDA toolkit. Set CUDA_HOME.") + + +def _nvcc_flags(cache_dir: Path, csrc: Path, cutlass: Path) -> str: + flags = [ + "-O3", "-std=c++20", + "--expt-relaxed-constexpr", "--expt-extended-lambda", + *nvcc_gencode_flags(), + "-static-global-template-stub=false", + "-DFLASHINFER_ENABLE_BF16", + "-DFLASHINFER_ENABLE_FP8_E4M3", + "-DFLASHINFER_ENABLE_FP8_E5M2", + "-DFLASHINFER_ENABLE_FP8_E8M0", + "-DFLASHINFER_ENABLE_FP4_E2M1", + "-DFLASHINFER_ENABLE_F16", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-Xcudafe", "--diag_suppress=2908", + f"-I{csrc / 'include'}", + f"-I{cutlass / 'include'}", + f"-I{cutlass / 'tools' / 'util' / 'include'}", + f"-I{_tvm_ffi_include()}", + f"-I{cache_dir}", + "-use_fast_math", + "-DNDEBUG", "-Xptxas", "-O3", + "-Xcompiler", "-fPIC", + ] + return " ".join(flags) + + +_module = None +_lock = threading.Lock() + + +def _compile() -> Path: + require_sm12x_csrc_arch("fmha_sm12x.sparse_topk") + pkg = _fmha_sm100_dir() + csrc = pkg / "csrc" + cutlass = pkg / "cutlass" + cache_dir = _cache_base() / "sparse_topk" + so_path = cache_dir / "sparse_topk_select.so" + if so_path.exists(): + return so_path + + cache_dir.mkdir(parents=True, exist_ok=True) + shutil.copy2(csrc / "sparse_topk_select.cu", cache_dir / "sparse_topk_select.cu") + # tvm_ffi_utils.h is included by relative name; csrc/ is not on the include + # path, so stage it inside cache_dir (which is) like the source build does. + header = csrc / "tvm_ffi_utils.h" + dst = cache_dir / "tvm_ffi_utils.h" + if not dst.exists() or dst.read_text() != header.read_text(): + shutil.copy2(header, dst) + + nvcc = os.path.join(_cuda_home(), "bin", "nvcc") + obj = cache_dir / "sparse_topk_select.o" + ninja_content = f"""ninja_required_version = 1.5 + +nvcc = {nvcc} +nvcc_flags = {_nvcc_flags(cache_dir, csrc, cutlass)} + +rule nvcc_compile + command = $nvcc $nvcc_flags -c $in -o $out + description = Compiling $in + +rule nvcc_link + command = $nvcc -shared $in -o $out -lcuda + description = Linking $out + +build {obj}: nvcc_compile {cache_dir / "sparse_topk_select.cu"} +build {so_path}: nvcc_link {obj} +""" + (cache_dir / "build.ninja").write_text(ninja_content) + result = subprocess.run(["ninja", "-j1"], cwd=str(cache_dir), capture_output=True, text=True) + if result.returncode != 0: + raise RuntimeError( + f"sparse_topk_select compilation failed:\nstdout: {result.stdout}\nstderr: {result.stderr}" + ) + return so_path + + +def get_sparse_topk_module(): + """Compile (once) and load the SM12x ``sparse_topk_select`` module.""" + + global _module + if _module is not None: + return _module + with _lock: + if _module is not None: + return _module + so_path = _compile() + import tvm_ffi + + _module = tvm_ffi.load_module(str(so_path)) + return _module + + +__all__ = ["get_sparse_topk_module"] diff --git a/python/fmha_sm12x/_triton_sparse.py b/python/fmha_sm12x/_triton_sparse.py new file mode 100644 index 0000000..37109f9 --- /dev/null +++ b/python/fmha_sm12x/_triton_sparse.py @@ -0,0 +1,337 @@ +# SPDX-License-Identifier: MIT + +"""Triton block-sparse attention for SM12x dense-KV prefill/decode. + +This is the semi-optimized SM120/SM121 path: a fused flash-attention kernel +that attends each query only to its top-k selected KV blocks, matching the +Torch reference in ``_reference.py`` / ``_lse.py`` numerically. Triton is an +optional accelerator (it ships with torch on Linux); callers fall back to the +Torch reference when :func:`triton_dense_supported` returns False. +""" + +from __future__ import annotations + +import math + +import torch + +try: + import triton + import triton.language as tl + + _HAS_TRITON = True +except ImportError: # pragma: no cover - exercised only without triton + _HAS_TRITON = False + + +def triton_available() -> bool: + return _HAS_TRITON + + +def triton_dense_supported( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + page_size: int, +) -> bool: + """Whether the dense Triton fast path can serve this call. + + Dense KV only (rank-3 ``[total_k, Hkv, D]``); paged KV and exotic dtypes + route to the Torch reference. + """ + + if not _HAS_TRITON: + return False + if not (q.is_cuda and k.is_cuda and v.is_cuda): + return False + if q.ndim != 3 or k.ndim != 3 or v.ndim != 3: + return False + if q.dtype not in (torch.bfloat16, torch.float16) or k.dtype != q.dtype or v.dtype != q.dtype: + return False + head_dim = int(q.shape[-1]) + if head_dim > 256 or int(v.shape[-1]) != head_dim: + return False + if int(page_size) <= 0 or int(page_size) > 256: + return False + return True + + +if _HAS_TRITON: + + @triton.jit + def _sparse_attn_kernel( + q_ptr, + k_ptr, + v_ptr, + idx_ptr, + o_ptr, + lse_ptr, + tlse_ptr, + kv_start_ptr, + kv_len_ptr, + vis_ptr, + sm_scale, + inv_temp, + h_ratio, + topk, + page_size, + stride_qn, + stride_qh, + stride_qd, + stride_kn, + stride_kh, + stride_kd, + stride_vn, + stride_vh, + stride_vd, + stride_in, + stride_ih, + stride_it, + stride_on, + stride_oh, + stride_od, + stride_ln, + stride_lh, + D: tl.constexpr, + BLOCK_D: tl.constexpr, + BLOCK_K: tl.constexpr, + BLOCK_H: tl.constexpr, + CAUSAL: tl.constexpr, + RETURN_LSE: tl.constexpr, + RETURN_TLSE: tl.constexpr, + ): + # One program owns a single query token and one KV head; it attends + # the whole GQA group (h_ratio q-heads that share this KV head) + # against that query's selected blocks. Per-token selection keeps the + # result identical to the reference even when tokens in a block differ. + q_index = tl.program_id(0) + kv_head = tl.program_id(1) + kv_start = tl.load(kv_start_ptr + q_index) + kv_len = tl.load(kv_len_ptr + q_index) + vis = tl.load(vis_ptr + q_index) + + off_h = tl.arange(0, BLOCK_H) + off_d = tl.arange(0, BLOCK_D) + off_k = tl.arange(0, BLOCK_K) + h_mask = off_h < h_ratio + d_mask = off_d < D + qh = kv_head * h_ratio + off_h + + q = tl.load( + q_ptr + q_index * stride_qn + qh[:, None] * stride_qh + off_d[None, :] * stride_qd, + mask=h_mask[:, None] & d_mask[None, :], + other=0.0, + ).to(tl.float32) + + m_i = tl.full((BLOCK_H,), float("-inf"), dtype=tl.float32) + l_i = tl.zeros((BLOCK_H,), dtype=tl.float32) + acc = tl.zeros((BLOCK_H, BLOCK_D), dtype=tl.float32) + m2 = tl.full((BLOCK_H,), float("-inf"), dtype=tl.float32) + l2 = tl.zeros((BLOCK_H,), dtype=tl.float32) + + for t in range(topk): + bid = tl.load(idx_ptr + q_index * stride_in + kv_head * stride_ih + t * stride_it) + base = bid * page_size + if (bid >= 0) and (base < kv_len): + pos = base + off_k + pos_mask = (off_k < page_size) & (pos < kv_len) + if CAUSAL: + pos_mask = pos_mask & (pos <= vis) + kv_row = kv_start + pos + k = tl.load( + k_ptr + kv_row[None, :] * stride_kn + kv_head * stride_kh + off_d[:, None] * stride_kd, + mask=d_mask[:, None] & pos_mask[None, :], + other=0.0, + ).to(tl.float32) + qk = tl.dot(q, k, input_precision="ieee") * sm_scale + qk = tl.where(pos_mask[None, :], qk, float("-inf")) + + m_ij = tl.maximum(m_i, tl.max(qk, axis=1)) + # m_ij == -inf only when no position has ever been visible; + # force alpha=1 there so exp(-inf - -inf) never yields NaN. + alpha = tl.where(m_ij == float("-inf"), 1.0, tl.exp(m_i - m_ij)) + p = tl.where(pos_mask[None, :], tl.exp(qk - m_ij[:, None]), 0.0) + l_i = l_i * alpha + tl.sum(p, axis=1) + acc = acc * alpha[:, None] + v = tl.load( + v_ptr + kv_row[:, None] * stride_vn + kv_head * stride_vh + off_d[None, :] * stride_vd, + mask=pos_mask[:, None] & d_mask[None, :], + other=0.0, + ).to(tl.float32) + acc += tl.dot(p, v, input_precision="ieee") + m_i = m_ij + + if RETURN_TLSE: + tqk = qk * inv_temp + m2_ij = tl.maximum(m2, tl.max(tqk, axis=1)) + alpha2 = tl.where(m2_ij == float("-inf"), 1.0, tl.exp(m2 - m2_ij)) + p2 = tl.where(pos_mask[None, :], tl.exp(tqk - m2_ij[:, None]), 0.0) + l2 = l2 * alpha2 + tl.sum(p2, axis=1) + m2 = m2_ij + + has_mass = l_i > 0 + out = tl.where(has_mass[:, None], acc / tl.where(has_mass[:, None], l_i[:, None], 1.0), 0.0) + tl.store( + o_ptr + q_index * stride_on + qh[:, None] * stride_oh + off_d[None, :] * stride_od, + out.to(o_ptr.dtype.element_ty), + mask=h_mask[:, None] & d_mask[None, :], + ) + if RETURN_LSE: + lse = tl.where(has_mass, m_i + tl.log(tl.where(has_mass, l_i, 1.0)), float("-inf")) + tl.store(lse_ptr + q_index * stride_ln + qh * stride_lh, lse, mask=h_mask) + if RETURN_TLSE: + has_mass2 = l2 > 0 + tlse = tl.where(has_mass2, m2 + tl.log(tl.where(has_mass2, l2, 1.0)), float("-inf")) + tl.store(tlse_ptr + q_index * stride_ln + qh * stride_lh, tlse, mask=h_mask) + + +def _per_query_geometry( + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + total_q: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Per-query kv_start, kv_len and causal visibility limit. + + ``vis[q] = (kv_len - qo_len) + local_q`` reproduces the reference default + ``qo_offset = kv_lens - qo_lens`` so causal masking matches token-for-token. + """ + + device = cu_seqlens_q.device + cu_q = cu_seqlens_q.to(torch.int64) + cu_k = cu_seqlens_k.to(torch.int64) + qo_len = cu_q[1:] - cu_q[:-1] + kv_len_b = cu_k[1:] - cu_k[:-1] + offset_b = kv_len_b - qo_len + batch_id = torch.repeat_interleave(torch.arange(qo_len.numel(), device=device), qo_len) + if int(batch_id.numel()) != int(total_q): + raise ValueError("cu_seqlens_q does not sum to total_q") + local_q = torch.arange(total_q, device=device) - cu_q[batch_id] + kv_start_q = cu_k[batch_id].to(torch.int32) + kv_len_q = kv_len_b[batch_id].to(torch.int32) + vis_q = (offset_b[batch_id] + local_q).to(torch.int32) + return kv_start_q.contiguous(), kv_len_q.contiguous(), vis_q.contiguous() + + +def triton_sparse_atten_dense( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + block_indexes: torch.Tensor, + *, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + num_kv_heads: int, + page_size: int, + causal: bool = False, + sm_scale: float | None = None, + return_lse: bool = False, + lse_temperature_scale: float = 1.0, + return_temperature_lse: bool = False, + out_dtype: torch.dtype | None = None, +) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: + """Fused dense block-sparse attention. + + ``block_indexes`` is int32 ``[total_q, num_kv_heads, topk]`` with -1 + padding (q2k selections, shared across each GQA group). Returns + ``(out, lse_or_None, temperature_lse_or_None)``; LSE tensors are + ``[total_q, Hq]`` float32 and -inf where a query selects no visible block. + ``out_dtype`` defaults to ``q.dtype``; pass it to round the float32 + accumulator straight to the final dtype (the SM12x surface uses bf16). + """ + + if not _HAS_TRITON: + raise RuntimeError("triton_sparse_atten_dense requires Triton") + if block_indexes.dtype != torch.int32 or block_indexes.ndim != 3: + raise ValueError("block_indexes must be int32 [total_q, num_kv_heads, topk]") + total_q, num_qo_heads, head_dim = (int(x) for x in q.shape) + num_kv_heads = int(num_kv_heads) + if num_qo_heads % num_kv_heads != 0: + raise ValueError("num_qo_heads must be divisible by num_kv_heads") + if int(block_indexes.shape[0]) != total_q or int(block_indexes.shape[1]) != num_kv_heads: + raise ValueError("block_indexes shape must be [total_q, num_kv_heads, topk]") + h_ratio = num_qo_heads // num_kv_heads + topk = int(block_indexes.shape[2]) + if sm_scale is None: + sm_scale = 1.0 / math.sqrt(head_dim) + lse_temperature_scale = float(lse_temperature_scale) + return_temperature_lse = bool(return_temperature_lse) and bool(return_lse) + + q = q.contiguous() + k = k.contiguous() + v = v.contiguous() + block_indexes = block_indexes.contiguous() + kv_start_q, kv_len_q, vis_q = _per_query_geometry(cu_seqlens_q, cu_seqlens_k, total_q) + + out = torch.empty( + (total_q, num_qo_heads, v.shape[-1]), + dtype=out_dtype if out_dtype is not None else q.dtype, + device=q.device, + ) + lse = ( + torch.empty((total_q, num_qo_heads), dtype=torch.float32, device=q.device) + if return_lse + else None + ) + tlse = ( + torch.empty((total_q, num_qo_heads), dtype=torch.float32, device=q.device) + if return_temperature_lse + else None + ) + lse_view = lse if lse is not None else out # unused when RETURN_LSE is False + tlse_view = tlse if tlse is not None else out + stride_ln = lse_view.stride(0) if lse is not None else 0 + stride_lh = lse_view.stride(1) if lse is not None else 0 + + # tl.dot needs all three tile dims (M GQA-group rows, N KV columns, and + # the K contraction = head dim) >= 16, so pad up; extra rows/columns/lanes + # are masked off in the kernel. + block_h = max(16, triton.next_power_of_2(h_ratio)) + block_d = max(16, triton.next_power_of_2(head_dim)) + block_k = max(16, triton.next_power_of_2(int(page_size))) + grid = (total_q, num_kv_heads) + _sparse_attn_kernel[grid]( + q, + k, + v, + block_indexes, + out, + lse_view, + tlse_view, + kv_start_q, + kv_len_q, + vis_q, + float(sm_scale), + 1.0 / lse_temperature_scale, + h_ratio, + topk, + int(page_size), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + block_indexes.stride(0), + block_indexes.stride(1), + block_indexes.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + stride_ln, + stride_lh, + D=head_dim, + BLOCK_D=block_d, + BLOCK_K=block_k, + BLOCK_H=block_h, + CAUSAL=bool(causal), + RETURN_LSE=bool(return_lse), + RETURN_TLSE=bool(return_temperature_lse), + num_warps=4, + ) + return out, lse, tlse + + +__all__ = ["triton_available", "triton_dense_supported", "triton_sparse_atten_dense"] diff --git a/python/fmha_sm12x/api.py b/python/fmha_sm12x/api.py new file mode 100644 index 0000000..027f0a0 --- /dev/null +++ b/python/fmha_sm12x/api.py @@ -0,0 +1,118 @@ +# SPDX-License-Identifier: MIT + +"""SM120/SM121 reference API for MiniMax Sparse Attention.""" + +from __future__ import annotations + +from typing import Optional, Union + +import torch + +from ._topk import get_sparse_topk_module + +from ._reference import Sm12xPlan, make_plan, run_plan + +__all__ = ["Sm12xPlan", "fmha_sm12x_plan", "fmha_sm12x", "sparse_topk_select"] + + +def fmha_sm12x_plan( + qo_segment_lens: torch.Tensor, + kv_segment_lens: torch.Tensor, + num_qo_heads: int, + num_kv_heads: int = -1, + qo_offset: Optional[Union[int, torch.Tensor]] = None, + num_kv_splits: int = -1, + page_size: int = -1, + output_maxscore: bool = False, + kv_block_num: int = -1, + usable_SM_count: int = -1, + causal: bool = True, + **_kwargs, +) -> Sm12xPlan: + """Build a semantic SM12x reference plan. + + ``num_kv_splits`` and ``usable_SM_count`` are accepted for API compatibility + but are intentionally ignored by this Torch reference backend. + """ + + _ = (num_kv_splits, usable_SM_count) + return make_plan( + qo_segment_lens, + kv_segment_lens, + num_qo_heads=int(num_qo_heads), + num_kv_heads=int(num_kv_heads), + qo_offset=qo_offset, + page_size=int(page_size), + output_maxscore=bool(output_maxscore), + kv_block_num=int(kv_block_num), + causal=bool(causal), + ) + + +def fmha_sm12x( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + plan_info: Sm12xPlan, + kv_indices: Optional[torch.Tensor] = None, + kv_block_indexes: Optional[torch.Tensor] = None, + q_offset_override: Optional[Union[int, torch.Tensor]] = None, + out: Optional[torch.Tensor] = None, + max_score: Optional[torch.Tensor] = None, + **kwargs, +) -> tuple[torch.Tensor | None, torch.Tensor | None]: + """Run the SM12x Torch reference attention path.""" + + if q_offset_override is not None: + plan_info = make_plan( + plan_info.qo_segment_lens, + plan_info.kv_segment_lens, + num_qo_heads=plan_info.num_qo_heads, + num_kv_heads=plan_info.num_kv_heads, + qo_offset=q_offset_override, + page_size=plan_info.page_size, + output_maxscore=plan_info.output_maxscore, + kv_block_num=plan_info.kv_block_num, + causal=plan_info.causal, + ) + return run_plan( + q, + k, + v, + plan_info, + kv_indices=kv_indices, + kv_block_indexes=kv_block_indexes, + out=out, + max_score=max_score, + sm_scale=kwargs.get("sm_scale"), + output_o=bool(kwargs.get("output_o", True)), + output_maxscore=bool(kwargs.get("output_maxscore", False)), + ) + + +def sparse_topk_select( + max_score: torch.Tensor, + topk: int, + num_valid_pages: Optional[int] = None, + output: Optional[torch.Tensor] = None, + force_begin_blocks: int = 0, + force_end_blocks: int = 0, +) -> torch.Tensor: + """SM12x-safe wrapper around the standalone CUDA top-k selector.""" + + if max_score.dtype != torch.float32: + raise TypeError(f"max_score must be float32, got {max_score.dtype}") + if max_score.ndim != 3 or not max_score.is_contiguous(): + raise ValueError("max_score must be contiguous with shape [heads, max_k_tiles, total_q]") + if int(topk) != 16: + raise ValueError(f"topk must be 16, got {topk}") + heads, max_k_tiles, total_q = (int(v) for v in max_score.shape) + if output is None: + output = torch.empty((total_q, heads, int(topk)), dtype=torch.int32, device=max_score.device) + valid_pages = max_k_tiles if num_valid_pages is None else int(num_valid_pages) + workspace = torch.empty((heads * max_k_tiles * total_q,), dtype=torch.int32, device=max_score.device) + get_sparse_topk_module().sparse_topk_select( + max_score, output, workspace, int(topk), int(valid_pages), + int(force_begin_blocks), int(force_end_blocks), torch.cuda.current_stream().cuda_stream, + ) + return output diff --git a/python/fmha_sm12x/arch.py b/python/fmha_sm12x/arch.py new file mode 100644 index 0000000..9291542 --- /dev/null +++ b/python/fmha_sm12x/arch.py @@ -0,0 +1,21 @@ +# SPDX-License-Identifier: MIT + +"""SM12x-facing architecture helpers.""" + +from minimax_msa.arch import ( # noqa: F401 + CudaArch, + cpp_extension_arch_flag, + cuda_arch_cache_suffix, + nvcc_gencode_flags, + require_sm12x_csrc_arch, + selected_cuda_arch, +) + +__all__ = [ + "CudaArch", + "cpp_extension_arch_flag", + "cuda_arch_cache_suffix", + "nvcc_gencode_flags", + "require_sm12x_csrc_arch", + "selected_cuda_arch", +] diff --git a/python/fmha_sm12x/cute/__init__.py b/python/fmha_sm12x/cute/__init__.py new file mode 100644 index 0000000..d5a8c89 --- /dev/null +++ b/python/fmha_sm12x/cute/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: MIT + +"""SM12x CuTe/CUDA helper package.""" diff --git a/python/fmha_sm12x/cute/src/__init__.py b/python/fmha_sm12x/cute/src/__init__.py new file mode 100644 index 0000000..48db8b3 --- /dev/null +++ b/python/fmha_sm12x/cute/src/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: MIT + +"""SM12x helper kernel source package.""" diff --git a/python/fmha_sm12x/cute/src/sm12x/__init__.py b/python/fmha_sm12x/cute/src/sm12x/__init__.py new file mode 100644 index 0000000..fc13bd1 --- /dev/null +++ b/python/fmha_sm12x/cute/src/sm12x/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: MIT + +"""SM120/SM121 helper kernels.""" diff --git a/python/fmha_sm12x/cute/src/sm12x/_schedule.py b/python/fmha_sm12x/cute/src/sm12x/_schedule.py new file mode 100644 index 0000000..d73f7b2 --- /dev/null +++ b/python/fmha_sm12x/cute/src/sm12x/_schedule.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: MIT + +"""Host-side schedule metadata helpers for SM12x sparse attention.""" + +from __future__ import annotations + +from dataclasses import dataclass + +import torch + + +@dataclass(slots=True) +class SparseAttentionSchedule: + enabled: bool + scheduler_metadata: torch.Tensor | None + work_count: torch.Tensor | None + qsplit_indices: torch.Tensor | None = None + split_counts: torch.Tensor | None = None + target_q_per_cta: int = 0 + + @property + def work_capacity(self) -> int: + return 0 if self.scheduler_metadata is None else int(self.scheduler_metadata.shape[0]) + + +SparseSchedulePlan = SparseAttentionSchedule + + +class SparseAttentionScheduleModel: + """Host-side helpers for sparse attention schedule sizing.""" + + @staticmethod + def _round_up(x: int, y: int) -> int: + return ((x + y - 1) // y) * y + + @staticmethod + def _ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + def _target_q_per_cta( + self, + *, + total_q: int, + topk: int, + head_kv: int, + qhead_per_kv: int, + device: torch.device, + usable_SM_count: int = -1, + ) -> int: + num_sm = torch.cuda.get_device_properties(device).multi_processor_count + if usable_SM_count > 0: + num_sm = min(int(usable_SM_count), num_sm) + q_tokens_per_group = 128 // qhead_per_kv + total_refs_upper = total_q * topk * head_kv + desired_work_items = max(num_sm * 2, 1) + total_groups_upper = self._ceil_div(max(total_refs_upper, 1), q_tokens_per_group) + target_groups_per_cta = min( + 512, + max(1, self._ceil_div(total_groups_upper, desired_work_items)), + ) + return target_groups_per_cta * q_tokens_per_group + + def balanced_target_q_per_cta( + self, + *, + total_q: int, + topk: int, + blk_kv: int, + head_kv: int, + qhead_per_kv: int, + device: torch.device, + usable_SM_count: int = -1, + ) -> int: + q_tokens_per_group = 128 // qhead_per_kv + occupancy_target = self._target_q_per_cta( + total_q=total_q, + topk=topk, + head_kv=head_kv, + qhead_per_kv=qhead_per_kv, + device=device, + usable_SM_count=usable_SM_count, + ) + sink_balance_cap = max(q_tokens_per_group, int(topk) * int(blk_kv) * 2) + target = min(max(occupancy_target, q_tokens_per_group), sink_balance_cap) + return self._round_up(target, q_tokens_per_group) + + def flat_schedule_capacity( + self, + *, + total_rows: int, + total_q: int, + topk: int, + head_kv: int, + target_q_per_cta: int, + ) -> int: + row_upper = max(total_rows, 0) * max(head_kv, 1) + refs_upper = max(total_q, 0) * max(topk, 1) * max(head_kv, 1) + split_upper = self._ceil_div(max(refs_upper, 1), max(target_q_per_cta, 1)) + return max(1, row_upper + split_upper) + + +SPARSE_SCHEDULE_MODEL = SparseAttentionScheduleModel() diff --git a/python/fmha_sm12x/cute/src/sm12x/build_k2q_csr/__init__.py b/python/fmha_sm12x/cute/src/sm12x/build_k2q_csr/__init__.py new file mode 100644 index 0000000..573d0c9 --- /dev/null +++ b/python/fmha_sm12x/cute/src/sm12x/build_k2q_csr/__init__.py @@ -0,0 +1,127 @@ +# SPDX-License-Identifier: MIT + +"""JIT-loaded CUDA C++ extension for the SM12x q2k -> k2q CSR builder.""" + +from __future__ import annotations + +import os + +import torch +from torch.utils.cpp_extension import load + +from minimax_msa.arch import ( + cpp_extension_arch_flag, + cuda_arch_cache_suffix, + require_sm12x_csrc_arch, +) + +_THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +_SRC = os.path.join(_THIS_DIR, "build_k2q_csr.cu") + +_EXTRA_CFLAGS = ["-O3"] +_EXTRA_CUDA_CFLAGS_BASE = [ + "-O3", + "--use_fast_math", + "-lineinfo", + "--ptxas-options=-v", + "--expt-relaxed-constexpr", +] + +_ext = None + + +def _cccl_include_flags() -> list[str]: + cuda_home = os.environ.get("CUDA_HOME", "/usr/local/cuda") + cccl = os.path.join(cuda_home, "include", "cccl") + return [f"-I{cccl}"] if os.path.isdir(cccl) else [] + + +def _load_ext(): + global _ext + if _ext is None: + require_sm12x_csrc_arch("fmha_sm12x.k2q_csr") + _ext = load( + name=f"sparse_build_k2q_csr_sm12x_ext{cuda_arch_cache_suffix()}", + sources=[_SRC], + extra_cflags=_EXTRA_CFLAGS, + extra_cuda_cflags=[ + *_EXTRA_CUDA_CFLAGS_BASE, + cpp_extension_arch_flag(), + *_cccl_include_flags(), + ], + verbose=False, + ) + return _ext + + +def run_build_k2q_csr( + q2k: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + row_ptr: torch.Tensor, + q_idx: torch.Tensor, + topk: int, + blk_kv: int, + total_rows: int, + max_kv_blocks: int, +) -> None: + """In-place fill of ``row_ptr`` and ``q_idx`` using the SM12x CUDA helper.""" + + _load_ext().run_build_k2q_csr( + q2k, + cu_seqlens_q, + cu_seqlens_k, + row_ptr, + q_idx, + int(topk), + int(blk_kv), + int(total_rows), + int(max_kv_blocks), + ) + + +def run_build_k2q_csr_with_schedule( + q2k: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + row_ptr: torch.Tensor, + q_idx: torch.Tensor, + scheduler_metadata: torch.Tensor, + work_count: torch.Tensor, + qsplit_idx: torch.Tensor, + split_counts: torch.Tensor, + topk: int, + blk_kv: int, + total_rows: int, + max_kv_blocks: int, + target_q_per_cta: int, + work_capacity: int, + max_seqlen_q: int, +) -> None: + """In-place fill of CSR plus sparse-attention schedule metadata.""" + + _load_ext().run_build_k2q_csr_with_schedule( + q2k, + cu_seqlens_q, + cu_seqlens_k, + row_ptr, + q_idx, + scheduler_metadata, + work_count, + qsplit_idx, + split_counts, + int(topk), + int(blk_kv), + int(total_rows), + int(max_kv_blocks), + int(target_q_per_cta), + int(work_capacity), + int(max_seqlen_q), + ) + + +def is_supported(topk: int, blk_kv: int) -> bool: + return int(topk) in (4, 8, 16, 32) and int(blk_kv) == 128 + + +__all__ = ["run_build_k2q_csr", "run_build_k2q_csr_with_schedule", "is_supported"] diff --git a/python/fmha_sm12x/cute/src/sm12x/build_k2q_csr/build_k2q_csr.cu b/python/fmha_sm12x/cute/src/sm12x/build_k2q_csr/build_k2q_csr.cu new file mode 100644 index 0000000..17eb85c --- /dev/null +++ b/python/fmha_sm12x/cute/src/sm12x/build_k2q_csr/build_k2q_csr.cu @@ -0,0 +1,873 @@ +// SPDX-License-Identifier: MIT + +// CUDA C++ q2k -> k2q CSR builder. +// +// Five-stage pipeline. q-ascending order within each CSR row is preserved +// by partitioning q across (CTA, warp_in_CTA) units; each unit owns a +// contiguous q-sub-range and reserves a contiguous slot range per row via +// a precomputed exclusive prefix scan. +// +// M: build_row_map -- round-robin packing of rows across batches +// H: histogram + tile_counts +// PR: row prefix -- single block per head, row_counts -> row_ptr +// PT: tile prefix -- multi-block, scan tile_counts along (c, w) axis +// S: scatter (sorted) -- per-warp slot range, q-sequential within warp +// +// Per-warp partitioning: each CTA has kWarps warps; warp w of CTA c owns +// q-range [c*q_per_cta + w*q_per_warp, c*q_per_cta + (w+1)*q_per_warp). +// tile_counts is shaped [G * kWarps, H, total_rows]; the "row" dimension +// of the prefix scan is the flattened (c * kWarps + w) index, scanned in +// lexicographic order so that warp-local slot ranges concatenate to the +// global q-sorted output. + +#include +#include +#include +#include + +#include + +#define CHECK_CUDA(x) TORCH_CHECK((x).is_cuda(), #x " must be CUDA") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK((x).is_contiguous(), #x " must be contiguous") +#define CHECK_INT(x) TORCH_CHECK((x).scalar_type() == at::kInt, #x " must be int32") +#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x); CHECK_INT(x) + +namespace { + +constexpr int kWarpSize = 32; + +__device__ __forceinline__ void advance_batch_only( + int const* __restrict__ cu_q, int B, int q_abs, int& bi) +{ + while (bi < B && cu_q[bi + 1] <= q_abs) ++bi; +} + +// Atomic increment of a 16-bit half within a 32-bit SMEM word; returns the +// OLD 16-bit value (slot). Per-warp count must stay < 32768 so the low +// half does not carry into the high half. +// base_int32 : int32 pointer; element i holds rows 2*i (low) and 2*i+1 (high). +__device__ __forceinline__ int atomic_inc_int16_packed( + int* base_int32, int row) +{ + int idx = row >> 1; + int shift = (row & 1) << 4; // 0 or 16 + int delta = 1 << shift; + int old = atomicAdd(&base_int32[idx], delta); + return (old >> shift) & 0xFFFF; +} + +// Read 16-bit half from packed int32 storage. +__device__ __forceinline__ int read_int16_packed(int const* base_int32, int row) { + int v = base_int32[row >> 1]; + int shift = (row & 1) << 4; + return (v >> shift) & 0xFFFF; +} + +// --------------------------------------------------------------------------- +// M: round-robin row map. +// --------------------------------------------------------------------------- +template +__global__ void k2q_build_row_map_kernel( + int const* __restrict__ cu_k, + int* __restrict__ row_map, + int* __restrict__ row_coords, + int B, + int max_kv_blocks) +{ + int level = blockIdx.x; + if (level >= max_kv_blocks) return; + if (threadIdx.x != 0) return; + int rows_before = 0; + for (int b = 0; b < B; ++b) { + int rb = (cu_k[b + 1] - cu_k[b] + kBlockK - 1) / kBlockK; + rows_before += (rb < level ? rb : level); + } + int active_before = 0; + for (int b = 0; b < B; ++b) { + int rb = (cu_k[b + 1] - cu_k[b] + kBlockK - 1) / kBlockK; + if (rb > level) { + int row_linear = rows_before + active_before; + row_map[(size_t)b * max_kv_blocks + level] = row_linear; + if (row_coords != nullptr) { + row_coords[(size_t)row_linear * 2] = b; + row_coords[(size_t)row_linear * 2 + 1] = level; + } + ++active_before; + } else { + row_map[(size_t)b * max_kv_blocks + level] = -1; + } + } +} + +// --------------------------------------------------------------------------- +// H: per-warp histogram + tile_counts. +// kWarps warps per CTA, each owns q-sub-range = q_per_cta / kWarps. +// SMEM hist[kWarps, total_rows] int32 (stored as packed int16 cursor: +// 2 entries per int32 word). Each warp counts to its own row. +// At end-of-CTA, write tile_counts[c*kWarps + w, h, r] = smem_hist[w, r] +// and atomicAdd(row_counts[h, r], sum over w of smem_hist[w, r]). +// --------------------------------------------------------------------------- +template +__global__ void k2q_hist_kernel( + int const* __restrict__ q2k, + int const* __restrict__ cu_q, + int const* __restrict__ row_map, + int* __restrict__ row_counts, + int* __restrict__ tile_counts, + int H, int B, int S_Q, + int total_rows, int max_kv_blocks, + int q_per_cta, int q_per_warp) +{ + constexpr int kThreads = kWarps * kWarpSize; + extern __shared__ int smem_hist_int[]; + int* smem_hist = smem_hist_int; + int tid = threadIdx.x; + int warp_id = tid >> 5; + int lane = tid & 31; + int c = blockIdx.x; + int q_start_cta = c * q_per_cta; + int q_end_cta = min(q_start_cta + q_per_cta, S_Q); + int q_start_warp = min(q_start_cta + warp_id * q_per_warp, q_end_cta); + int q_end_warp = min(q_start_warp + q_per_warp, q_end_cta); + + constexpr int kInt4PerToken = kTopK / 4; + int packed_per_warp = (total_rows + 1) >> 1; + int* my_hist = smem_hist + warp_id * packed_per_warp; + + for (int h = 0; h < H; ++h) { + for (int i = lane; i < packed_per_warp; i += kWarpSize) my_hist[i] = 0; + __syncthreads(); + + if (q_start_warp < q_end_warp) { + int bi = 0; + int qi = q_start_warp + lane; + advance_batch_only(cu_q, B, qi, bi); + + int4 const* head_topk4 = + reinterpret_cast(q2k + (size_t)h * S_Q * kTopK); + + for (; qi < q_end_warp; qi += kWarpSize) { + advance_batch_only(cu_q, B, qi, bi); + int const* my_row_map = row_map + (size_t)bi * max_kv_blocks; + + int4 buf[kInt4PerToken]; + #pragma unroll + for (int v = 0; v < kInt4PerToken; ++v) { + buf[v] = head_topk4[(size_t)qi * kInt4PerToken + v]; + } + #pragma unroll + for (int t = 0; t < kTopK; ++t) { + int kvb_local = reinterpret_cast(buf)[t]; + if (kvb_local >= 0 && kvb_local < max_kv_blocks) { + int row = my_row_map[kvb_local]; + if (row >= 0 && row < total_rows) { + atomic_inc_int16_packed(my_hist, row); + } + } + } + } + } + __syncthreads(); + + int* head_row_counts = row_counts + (size_t)h * total_rows; + // Each warp writes its own slice of tile_counts (full int32) by + // unpacking int16 entries from SMEM. + int* my_tile = tile_counts + + ((size_t)(c * kWarps + warp_id) * H + h) * total_rows; + for (int i = lane; i < total_rows; i += kWarpSize) { + my_tile[i] = read_int16_packed(my_hist, i); + } + __syncthreads(); + + // Sum across warps (int32 accumulator), atomicAdd to row_counts. + for (int i = tid; i < total_rows; i += kThreads) { + int sum = 0; + #pragma unroll + for (int w = 0; w < kWarps; ++w) { + sum += read_int16_packed(smem_hist + w * packed_per_warp, i); + } + if (sum > 0) atomicAdd(&head_row_counts[i], sum); + } + if (h + 1 < H) __syncthreads(); + } +} + +// --------------------------------------------------------------------------- +// PR: row prefix. One block per head. +// --------------------------------------------------------------------------- +template +__global__ void k2q_row_prefix_kernel( + int const* __restrict__ row_counts, + int* __restrict__ row_ptr, + int const* __restrict__ row_coords, + int* __restrict__ scheduler_metadata, + int* __restrict__ work_count, + int total_rows, + int target_q_per_cta, + int work_capacity) +{ + int h = blockIdx.x; + int tid = threadIdx.x; + __shared__ int scan_buf[kThreads]; + + int const* head_counts = row_counts + (size_t)h * total_rows; + int* head_rowptr = row_ptr + (size_t)h * (total_rows + 1); + int chunk = (total_rows + kThreads - 1) / kThreads; + int lo = tid * chunk; + int hi = min(lo + chunk, total_rows); + + int local_sum = 0; + for (int i = lo; i < hi; ++i) local_sum += head_counts[i]; + scan_buf[tid] = local_sum; + __syncthreads(); + + for (int off = 1; off < kThreads; off <<= 1) { + int add = (tid >= off) ? scan_buf[tid - off] : 0; + __syncthreads(); + scan_buf[tid] += add; + __syncthreads(); + } + int running = scan_buf[tid] - local_sum; + for (int i = lo; i < hi; ++i) { + int row_count = head_counts[i]; + running += row_count; + head_rowptr[i + 1] = running; + if (scheduler_metadata != nullptr && work_count != nullptr && row_count > 0) { + int num_chunks = (row_count + target_q_per_cta - 1) / target_q_per_cta; + int base = atomicAdd(work_count, num_chunks); + int batch_idx = row_coords[(size_t)i * 2]; + int kv_block_idx = row_coords[(size_t)i * 2 + 1]; + for (int c = 0; c < num_chunks; ++c) { + int work_idx = base + c; + if (work_idx < work_capacity) { + int q_begin = c * target_q_per_cta; + int q_count = min(target_q_per_cta, row_count - q_begin); + int* meta = scheduler_metadata + (size_t)work_idx * 6; + meta[0] = h; + meta[1] = i; + meta[2] = q_begin; + meta[3] = q_count; + meta[4] = batch_idx; + meta[5] = kv_block_idx; + } + } + } + } +} + +// --------------------------------------------------------------------------- +// PT_smem: SMEM-staged tile prefix scan. +// Each block handles kRowsPerBlock rows for one head h. Cooperative load +// of tile_counts[*, h, base_r..base_r+M) into SMEM (better coalescing +// than per-warp uncoalesced stride reads), then per-warp scan in SMEM, +// then cooperative store back. Fuses row_ptr into the base. +// --------------------------------------------------------------------------- +template +__global__ void k2q_tile_prefix_smem_kernel( + int* __restrict__ tile_counts, + int const* __restrict__ row_ptr, + int H, int total_rows, int G_total) +{ + static_assert(kRowsPerBlock > 0, "kRowsPerBlock must be positive"); + extern __shared__ int smem_tprefix[]; + // smem layout: smem[r_off][g] for r_off in [0, M), g in [0, G_total). + + int tid = threadIdx.x; + int lane = tid & 31; + int warp_id = tid >> 5; + + // Grid: H * blocks_per_h. Each block stays within a single head h + // and processes kRowsPerBlock contiguous rows starting at b_in_h * + // kRowsPerBlock. (Earlier flat-grid mapping `h = block_job / + // total_rows; base_r = block_job - h*total_rows` skipped rows when + // total_rows was not a multiple of kRowsPerBlock and H > 1, because + // the last partial block of head h-1 left blocks of head h starting + // at a non-zero row offset.) + int blocks_per_h = (total_rows + kRowsPerBlock - 1) / kRowsPerBlock; + int h = blockIdx.x / blocks_per_h; + int b_in_h = blockIdx.x - h * blocks_per_h; + if (h >= H) return; + int base_r = b_in_h * kRowsPerBlock; + if (base_r >= total_rows) return; + int actual_M = min(kRowsPerBlock, total_rows - base_r); + + size_t stride_g = (size_t)H * total_rows; + int* base_ptr = tile_counts + (size_t)h * total_rows + base_r; + int total_elems = G_total * actual_M; + + // Cooperative load. Pattern: thread tid -> (r_off=tid%M, g=tid/M), + // then strided. 32 lanes hit M r's x (32/M) g's, giving 32/M cache + // lines per warp (vs 32 in the naive stride-along-g pattern). + for (int i = tid; i < total_elems; i += kThreads) { + int r_off = i % actual_M; + int g = i / actual_M; + smem_tprefix[r_off * G_total + g] = base_ptr[g * stride_g + r_off]; + } + __syncthreads(); + + // Per-warp scan: warp w scans row (base_r + w) if w < actual_M. + if (warp_id < actual_M) { + int abs_r = base_r + warp_id; + int rp = row_ptr[(size_t)h * (total_rows + 1) + abs_r]; + int* my_smem = smem_tprefix + warp_id * G_total; + int running = rp; + for (int g0 = 0; g0 < G_total; g0 += kWarpSize) { + int g = g0 + lane; + int v = (g < G_total) ? my_smem[g] : 0; + int x = v; + #pragma unroll + for (int off = 1; off < kWarpSize; off <<= 1) { + int nbr = __shfl_up_sync(0xFFFFFFFF, x, off); + if (lane >= off) x += nbr; + } + int excl = running + x - v; + if (g < G_total) my_smem[g] = excl; + int chunk_sum = __shfl_sync(0xFFFFFFFF, x, 31); + running += chunk_sum; + } + } + __syncthreads(); + + // Cooperative store back. + for (int i = tid; i < total_elems; i += kThreads) { + int r_off = i % actual_M; + int g = i / actual_M; + base_ptr[g * stride_g + r_off] = smem_tprefix[r_off * G_total + g]; + } +} + +// --------------------------------------------------------------------------- +// S: scatter. kWarps warps per CTA, each owns q-sub-range. Per-warp SMEM +// cursor and per-warp tile_offset slot range. Within a warp, q's are +// processed sequentially; lanes 0..kTopK-1 handle the topK slots in +// lockstep. Across distinct q's in the same warp, the lockstep ordering +// guarantees q-monotonic atomicAdd on smem_cursor[r]. +// --------------------------------------------------------------------------- +// kQPerIter * kTopK lanes are active per warp iter; remaining lanes idle. +// For kTopK=16, kQPerIter=2 uses all 32 lanes; for kTopK=8, kQPerIter=4. +// CORRECTNESS NOTE: relies on lane-ordered SMEM atomicAdd return values +// within a single warp instruction (verified on SM100; tests pass). +// +// SMEM cursor stored as packed int16 (two cursors per int32). Per-warp +// row count must stay < 32768 (~q_per_warp * kTopK at max sink), which +// holds for all task.md sizes up to 1024K. +template +__global__ void k2q_scatter_kernel( + int const* __restrict__ q2k, + int const* __restrict__ cu_q, + int const* __restrict__ row_map, + int const* __restrict__ abs_base, + int* __restrict__ q_idx, + int* __restrict__ qsplit_idx, + int* __restrict__ split_counts, + int H, int B, int S_Q, + int total_rows, int max_kv_blocks, + int q_per_cta, int q_per_warp, + int max_seqlen_q) +{ + constexpr int kQPerIter = kWarpSize / kTopK > 0 ? kWarpSize / kTopK : 1; + extern __shared__ int smem_cursor_int[]; + int* smem_cursor = smem_cursor_int; + int tid = threadIdx.x; + int warp_id = tid >> 5; + int lane = tid & 31; + int c = blockIdx.x; + int q_start_cta = c * q_per_cta; + int q_end_cta = min(q_start_cta + q_per_cta, S_Q); + int q_start_warp = min(q_start_cta + warp_id * q_per_warp, q_end_cta); + int q_end_warp = min(q_start_warp + q_per_warp, q_end_cta); + + int q_in_iter = lane / kTopK; + int slot_in_q = lane % kTopK; + bool lane_active = (lane < kQPerIter * kTopK); + + // Per-warp packed cursor: total_rows int16 entries -> ceil(total_rows/2) int32. + int packed_per_warp = (total_rows + 1) >> 1; + int* my_cursor = smem_cursor + warp_id * packed_per_warp; + + for (int h = 0; h < H; ++h) { + for (int i = lane; i < packed_per_warp; i += kWarpSize) my_cursor[i] = 0; + __syncwarp(); + + if (q_start_warp < q_end_warp) { + int bi = 0; + advance_batch_only(cu_q, B, q_start_warp, bi); + + int const* head_q2k = q2k + (size_t)h * S_Q * kTopK; + int const* my_abs_base = + abs_base + ((size_t)(c * kWarps + warp_id) * H + h) * total_rows; + int* head_qidx = q_idx + (size_t)h * S_Q * kTopK; + + // (Hot-row register cache experiment showed no measurable + // benefit; relying on L1 to keep row 0 / row total_rows-1 + // hot since they're hit every iteration in sink workloads.) + + constexpr int kUnroll = 16; + int qi_base = q_start_warp; + for (; qi_base + kUnroll * kQPerIter <= q_end_warp; + qi_base += kUnroll * kQPerIter) { + int kvb[kUnroll]; + int qloc[kUnroll]; + int batch[kUnroll]; + int const* rmap[kUnroll]; + + #pragma unroll + for (int u = 0; u < kUnroll; ++u) { + int qi_u = qi_base + u * kQPerIter + q_in_iter; + kvb[u] = -1; + qloc[u] = 0; + batch[u] = 0; + if (lane_active) { + advance_batch_only(cu_q, B, qi_u, bi); + qloc[u] = qi_u - cu_q[bi]; + batch[u] = bi; + kvb[u] = head_q2k[(size_t)qi_u * kTopK + slot_in_q]; + } + rmap[u] = row_map + (size_t)bi * max_kv_blocks; + } + + int row[kUnroll]; + #pragma unroll + for (int u = 0; u < kUnroll; ++u) { + row[u] = -1; + if (lane_active && kvb[u] >= 0 && kvb[u] < max_kv_blocks) + row[u] = rmap[u][kvb[u]]; + } + + // Pre-issue all kUnroll abs_base loads in parallel before + // the atomic chain so memory pipeline runs concurrently + // with SMEM atomic-adds. + int abs_v[kUnroll]; + #pragma unroll + for (int u = 0; u < kUnroll; ++u) { + abs_v[u] = (row[u] >= 0 && row[u] < total_rows) + ? my_abs_base[row[u]] : 0; + } + + #pragma unroll + for (int u = 0; u < kUnroll; ++u) { + int r = row[u]; + bool valid_edge = r >= 0 && r < total_rows; + unsigned int valid_mask = __ballot_sync(0xFFFFFFFFu, valid_edge); + unsigned int group_mask = (kTopK == 32) + ? 0xFFFFFFFFu + : (((1u << kTopK) - 1u) << (q_in_iter * kTopK)); + unsigned int lower_lane_mask = lane == 0 ? 0u : ((1u << lane) - 1u); + int split_slot = __popc(valid_mask & group_mask & lower_lane_mask); + int valid_count = __popc(valid_mask & group_mask); + if (split_counts != nullptr && slot_in_q == 0) { + int q_abs = cu_q[batch[u]] + qloc[u]; + split_counts[(size_t)q_abs * H + h] = valid_count; + } + if (valid_edge) { + int slot = atomic_inc_int16_packed(my_cursor, r); + int out_pos = abs_v[u] + slot; + head_qidx[out_pos] = qloc[u]; + if (qsplit_idx != nullptr) { + qsplit_idx[(size_t)h * S_Q * kTopK + out_pos] = + qloc[u] | ((split_slot & 0xFF) << 24); + } + } + } + } + // Tail: 1-3 iters left. + for (; qi_base < q_end_warp; qi_base += kQPerIter) { + int my_qi = qi_base + q_in_iter; + bool valid_q = (my_qi < q_end_warp) && lane_active; + int kvb_local = -1; + int q_local = 0; + int batch_local = 0; + if (valid_q) { + advance_batch_only(cu_q, B, my_qi, bi); + batch_local = bi; + q_local = my_qi - cu_q[bi]; + kvb_local = head_q2k[(size_t)my_qi * kTopK + slot_in_q]; + } + int const* my_row_map = row_map + (size_t)bi * max_kv_blocks; + int row = -1; + if (valid_q && kvb_local >= 0 && kvb_local < max_kv_blocks) { + row = my_row_map[kvb_local]; + } + bool valid_edge = row >= 0 && row < total_rows; + unsigned int valid_mask = __ballot_sync(0xFFFFFFFFu, valid_edge); + unsigned int group_mask = (kTopK == 32) + ? 0xFFFFFFFFu + : (((1u << kTopK) - 1u) << (q_in_iter * kTopK)); + unsigned int lower_lane_mask = lane == 0 ? 0u : ((1u << lane) - 1u); + int split_slot = __popc(valid_mask & group_mask & lower_lane_mask); + int valid_count = __popc(valid_mask & group_mask); + if (split_counts != nullptr && valid_q && slot_in_q == 0) { + split_counts[(size_t)my_qi * H + h] = valid_count; + } + if (valid_edge) { + int slot = atomic_inc_int16_packed(my_cursor, row); + int out_pos = my_abs_base[row] + slot; + head_qidx[out_pos] = q_local; + if (qsplit_idx != nullptr) { + qsplit_idx[(size_t)h * S_Q * kTopK + out_pos] = + q_local | ((split_slot & 0xFF) << 24); + } + } + } + } + if (h + 1 < H) __syncthreads(); + } +} + +} // anonymous namespace + +// =========================================================================== +// Host orchestration +// =========================================================================== + +template +static void launch_pipeline( + torch::Tensor q2k, + torch::Tensor cu_q, + torch::Tensor cu_k, + torch::Tensor row_ptr, + torch::Tensor q_idx, + int total_rows, + int max_kv_blocks, + torch::Tensor scheduler_metadata = torch::Tensor(), + torch::Tensor work_count = torch::Tensor(), + torch::Tensor qsplit_idx = torch::Tensor(), + torch::Tensor split_counts = torch::Tensor(), + int target_q_per_cta = 1, + int work_capacity = 0, + int max_seqlen_q = 0) +{ + int H = (int)q2k.size(0); + int S_Q = (int)q2k.size(1); + int topK = (int)q2k.size(2); + TORCH_CHECK(topK == kTopK, "topK runtime != template kTopK"); + int B = (int)cu_q.size(0) - 1; + auto device = q2k.device(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + AT_CUDA_CHECK(cudaMemsetAsync( + row_ptr.data_ptr(), 0, + (size_t)H * (total_rows + 1) * sizeof(int), stream)); + AT_CUDA_CHECK(cudaMemsetAsync( + q_idx.data_ptr(), 0xFF, + (size_t)H * S_Q * kTopK * sizeof(int), stream)); + + auto opts = torch::TensorOptions().dtype(torch::kInt32).device(device); + auto row_counts = torch::zeros({H, total_rows}, opts); + auto row_map = torch::empty({B, max_kv_blocks}, opts); + bool emit_schedule = scheduler_metadata.defined(); + auto row_coords = emit_schedule ? torch::empty({total_rows, 2}, opts) : torch::Tensor(); + int* scheduler_metadata_ptr = emit_schedule ? scheduler_metadata.data_ptr() : nullptr; + int* work_count_ptr = emit_schedule ? work_count.data_ptr() : nullptr; + int* qsplit_idx_ptr = emit_schedule ? qsplit_idx.data_ptr() : nullptr; + int* split_counts_ptr = emit_schedule ? split_counts.data_ptr() : nullptr; + int* row_coords_ptr = emit_schedule ? row_coords.data_ptr() : nullptr; + if (emit_schedule) { + AT_CUDA_CHECK(cudaMemsetAsync(work_count_ptr, 0, sizeof(int), stream)); + AT_CUDA_CHECK(cudaMemsetAsync( + scheduler_metadata_ptr, 0, + (size_t)work_capacity * 6 * sizeof(int), stream)); + } + + int dev = q2k.get_device(); + int num_sms = 0; + AT_CUDA_CHECK(cudaDeviceGetAttribute( + &num_sms, cudaDevAttrMultiProcessorCount, dev)); + int max_smem_per_block = 0; + AT_CUDA_CHECK(cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev)); + if (max_smem_per_block <= 0) { + AT_CUDA_CHECK(cudaDeviceGetAttribute( + &max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlock, dev)); + } + int max_smem_per_sm = 0; + AT_CUDA_CHECK(cudaDeviceGetAttribute( + &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev)); + if (max_smem_per_sm <= 0) max_smem_per_sm = max_smem_per_block; + + // -- Pick kWarps per CTA based on device SMEM budget for cursor/hist. + // SMEM cursor is packed as int16 (2 rows per int32 word). + int per_warp_smem = ((total_rows + 1) >> 1) * (int)sizeof(int); + TORCH_CHECK(per_warp_smem <= max_smem_per_block, + "k2q CSR row count exceeds dynamic shared memory limit"); + int kWarps_pick = 4; + while (kWarps_pick > 1 && + kWarps_pick * per_warp_smem > max_smem_per_block) { + kWarps_pick >>= 1; + } + while (kWarps_pick > 1 && + (kWarps_pick * per_warp_smem) * 2 > max_smem_per_sm) { + kWarps_pick >>= 1; + } + if (kWarps_pick < 1) kWarps_pick = 1; + + // -- Pick G (CTAs) ---------------------------------------------------- + // Size one resident wave from actual SM12x shared-memory attributes. + int per_cta_smem_bytes = kWarps_pick * per_warp_smem; + int max_ctas_per_sm = std::max( + 1, max_smem_per_sm / std::max(1, per_cta_smem_bytes)); + if (max_ctas_per_sm > 8) max_ctas_per_sm = 8; + constexpr int kMinQPerCta = 256; + // Cap target_g at num_sms * 3 - empirically this balances + // per-CTA work-size against parallelism. Higher caps regress + // mid-size cases due to row_counts atomicAdd contention and + // smaller q_per_cta. SMEM-bound configurations naturally cap + // lower if max_ctas_per_sm < 3. + int target_g = num_sms * std::min(max_ctas_per_sm, 3); + int max_g_for_q = (S_Q + kMinQPerCta - 1) / kMinQPerCta; + int G = std::min({target_g, max_g_for_q, S_Q}); + if (G < 1) G = 1; + constexpr int kPackedCounterLimit = 32767; + int max_q_per_warp_safe = std::max(1, kPackedCounterLimit / kTopK); + int max_q_per_cta_safe = std::max(1, max_q_per_warp_safe * kWarps_pick); + int min_g_for_counter = (S_Q + max_q_per_cta_safe - 1) / max_q_per_cta_safe; + G = std::max(G, min_g_for_counter); + G = std::min(G, S_Q); + int q_per_cta = (S_Q + G - 1) / G; + G = (S_Q + q_per_cta - 1) / q_per_cta; + int q_per_warp = (q_per_cta + kWarps_pick - 1) / kWarps_pick; + TORCH_CHECK(q_per_warp * kTopK <= kPackedCounterLimit, + "k2q CSR per-warp counter would overflow packed int16 storage"); + int G_total = G * kWarps_pick; + + auto tile_counts = torch::empty({G_total, H, total_rows}, opts); + + // -- Compile-time switch on kWarps for the templated kernels --------- + auto rmap_fn = k2q_build_row_map_kernel; + auto rprefix_fn = k2q_row_prefix_kernel<1024>; + constexpr int kPtRowsPerBlock = 8; + constexpr int kPtThreads = 256; + auto tprefix_smem_fn = k2q_tile_prefix_smem_kernel; + + if (max_kv_blocks > 0) { + rmap_fn<<>>( + cu_k.data_ptr(), row_map.data_ptr(), row_coords_ptr, B, max_kv_blocks); + } + + auto launch_hist_scatter = [&](auto kWarps_const) { + constexpr int W = decltype(kWarps_const)::value; + size_t smem_bytes = (size_t)W * per_warp_smem; + auto hist_fn = k2q_hist_kernel; + auto scat_fn = k2q_scatter_kernel; + AT_CUDA_CHECK(cudaFuncSetAttribute( + hist_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, (int)smem_bytes)); + AT_CUDA_CHECK(cudaFuncSetAttribute( + scat_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, (int)smem_bytes)); + + hist_fn<<>>( + q2k.data_ptr(), cu_q.data_ptr(), row_map.data_ptr(), + row_counts.data_ptr(), tile_counts.data_ptr(), + H, B, S_Q, total_rows, max_kv_blocks, q_per_cta, q_per_warp); + + rprefix_fn<<>>( + row_counts.data_ptr(), row_ptr.data_ptr(), + emit_schedule ? row_coords.data_ptr() : nullptr, + scheduler_metadata_ptr, + work_count_ptr, + total_rows, + target_q_per_cta, + work_capacity); + + // Grid is H * blocks_per_h so each block stays within a single + // head; flat (H*total_rows) grid would skip rows when total_rows + // is not a multiple of kPtRowsPerBlock. + int blocks_per_h = (total_rows + kPtRowsPerBlock - 1) / kPtRowsPerBlock; + int pt_grid = H * blocks_per_h; + if (pt_grid < 1) pt_grid = 1; + size_t pt_smem = (size_t)kPtRowsPerBlock * G_total * sizeof(int); + AT_CUDA_CHECK(cudaFuncSetAttribute( + tprefix_smem_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, + (int)pt_smem)); + tprefix_smem_fn<<>>( + tile_counts.data_ptr(), row_ptr.data_ptr(), + H, total_rows, G_total); + + scat_fn<<>>( + q2k.data_ptr(), cu_q.data_ptr(), row_map.data_ptr(), + tile_counts.data_ptr(), q_idx.data_ptr(), + qsplit_idx_ptr, split_counts_ptr, + H, B, S_Q, total_rows, max_kv_blocks, q_per_cta, q_per_warp, + max_seqlen_q); + }; + + if (kWarps_pick == 4) { + launch_hist_scatter(std::integral_constant{}); + } else if (kWarps_pick == 2) { + launch_hist_scatter(std::integral_constant{}); + } else { + launch_hist_scatter(std::integral_constant{}); + } +} + +void run_build_k2q_csr( + torch::Tensor q2k, + torch::Tensor cu_q, + torch::Tensor cu_k, + torch::Tensor row_ptr, + torch::Tensor q_idx, + int64_t topk, + int64_t blk_kv, + int64_t total_rows, + int64_t max_kv_blocks) +{ + CHECK_INPUT(q2k); + CHECK_INPUT(cu_q); + CHECK_INPUT(cu_k); + CHECK_INPUT(row_ptr); + CHECK_INPUT(q_idx); + TORCH_CHECK(blk_kv == 128, "build_k2q_csr only supports blk_kv == 128"); + int H = (int)q2k.size(0); + int S_Q = (int)q2k.size(1); + int tr = (int)total_rows; + int mkv = (int)max_kv_blocks; + TORCH_CHECK(tr >= 0 && mkv >= 0, + "total_rows / max_kv_blocks must be non-negative"); + TORCH_CHECK(row_ptr.size(0) == H && row_ptr.size(1) == tr + 1, + "row_ptr shape mismatch"); + TORCH_CHECK(q_idx.size(0) == H && q_idx.size(1) == (int64_t)S_Q * (int)topk, + "q_idx shape mismatch"); + if (S_Q == 0 || tr == 0 || H == 0 || mkv == 0) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_CUDA_CHECK(cudaMemsetAsync( + row_ptr.data_ptr(), 0, + (size_t)H * (tr + 1) * sizeof(int), stream)); + AT_CUDA_CHECK(cudaMemsetAsync( + q_idx.data_ptr(), 0xFF, + (size_t)H * S_Q * (int)topk * sizeof(int), stream)); + return; + } + + if (topk == 16) { + launch_pipeline<16, 128>(q2k, cu_q, cu_k, row_ptr, q_idx, tr, mkv); + } else if (topk == 8) { + launch_pipeline<8, 128>(q2k, cu_q, cu_k, row_ptr, q_idx, tr, mkv); + } else if (topk == 32) { + launch_pipeline<32, 128>(q2k, cu_q, cu_k, row_ptr, q_idx, tr, mkv); + } else if (topk == 4) { + launch_pipeline<4, 128>(q2k, cu_q, cu_k, row_ptr, q_idx, tr, mkv); + } else { + TORCH_CHECK(false, "unsupported topK ", topk, " (expected 4, 8, 16, or 32)"); + } +} + +void run_build_k2q_csr_with_schedule( + torch::Tensor q2k, + torch::Tensor cu_q, + torch::Tensor cu_k, + torch::Tensor row_ptr, + torch::Tensor q_idx, + torch::Tensor scheduler_metadata, + torch::Tensor work_count, + torch::Tensor qsplit_idx, + torch::Tensor split_counts, + int64_t topk, + int64_t blk_kv, + int64_t total_rows, + int64_t max_kv_blocks, + int64_t target_q_per_cta, + int64_t work_capacity, + int64_t max_seqlen_q) +{ + CHECK_INPUT(q2k); + CHECK_INPUT(cu_q); + CHECK_INPUT(cu_k); + CHECK_INPUT(row_ptr); + CHECK_INPUT(q_idx); + CHECK_INPUT(scheduler_metadata); + CHECK_INPUT(work_count); + CHECK_INPUT(qsplit_idx); + CHECK_INPUT(split_counts); + TORCH_CHECK(blk_kv == 128, "build_k2q_csr only supports blk_kv == 128"); + int H = (int)q2k.size(0); + int S_Q = (int)q2k.size(1); + int tr = (int)total_rows; + int mkv = (int)max_kv_blocks; + int target = (int)target_q_per_cta; + int capacity = (int)work_capacity; + int max_sq = (int)max_seqlen_q; + TORCH_CHECK(tr >= 0 && mkv >= 0 && target > 0 && capacity > 0 && max_sq >= 0, + "invalid schedule sizing arguments"); + TORCH_CHECK(row_ptr.size(0) == H && row_ptr.size(1) == tr + 1, + "row_ptr shape mismatch"); + TORCH_CHECK(q_idx.size(0) == H && q_idx.size(1) == (int64_t)S_Q * (int)topk, + "q_idx shape mismatch"); + TORCH_CHECK(qsplit_idx.sizes() == q_idx.sizes(), "qsplit_idx shape mismatch"); + TORCH_CHECK(scheduler_metadata.size(0) == capacity && scheduler_metadata.size(1) == 6, + "scheduler_metadata shape mismatch"); + TORCH_CHECK(work_count.numel() == 1, "work_count must have one int32 element"); + TORCH_CHECK(split_counts.dim() == 2 && split_counts.size(0) == S_Q + && split_counts.size(1) == H, + "split_counts shape mismatch"); + if (S_Q == 0 || tr == 0 || H == 0 || mkv == 0) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + AT_CUDA_CHECK(cudaMemsetAsync( + row_ptr.data_ptr(), 0, + (size_t)H * (tr + 1) * sizeof(int), stream)); + AT_CUDA_CHECK(cudaMemsetAsync( + q_idx.data_ptr(), 0xFF, + (size_t)H * S_Q * (int)topk * sizeof(int), stream)); + AT_CUDA_CHECK(cudaMemsetAsync(work_count.data_ptr(), 0, sizeof(int), stream)); + if (split_counts.numel() > 0) { + AT_CUDA_CHECK(cudaMemsetAsync( + split_counts.data_ptr(), 0, + (size_t)split_counts.numel() * sizeof(int), stream)); + } + return; + } + + if (topk == 16) { + launch_pipeline<16, 128>( + q2k, cu_q, cu_k, row_ptr, q_idx, tr, mkv, + scheduler_metadata, work_count, qsplit_idx, split_counts, + target, capacity, max_sq); + } else if (topk == 8) { + launch_pipeline<8, 128>( + q2k, cu_q, cu_k, row_ptr, q_idx, tr, mkv, + scheduler_metadata, work_count, qsplit_idx, split_counts, + target, capacity, max_sq); + } else if (topk == 32) { + launch_pipeline<32, 128>( + q2k, cu_q, cu_k, row_ptr, q_idx, tr, mkv, + scheduler_metadata, work_count, qsplit_idx, split_counts, + target, capacity, max_sq); + } else if (topk == 4) { + launch_pipeline<4, 128>( + q2k, cu_q, cu_k, row_ptr, q_idx, tr, mkv, + scheduler_metadata, work_count, qsplit_idx, split_counts, + target, capacity, max_sq); + } else { + TORCH_CHECK(false, "unsupported topK ", topk, " (expected 4, 8, 16, or 32)"); + } +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("run_build_k2q_csr", &run_build_k2q_csr, + "q2k -> k2q CSR build (sorted within row)", + pybind11::arg("q2k"), + pybind11::arg("cu_q"), + pybind11::arg("cu_k"), + pybind11::arg("row_ptr"), + pybind11::arg("q_idx"), + pybind11::arg("topk"), + pybind11::arg("blk_kv"), + pybind11::arg("total_rows"), + pybind11::arg("max_kv_blocks")); + m.def("run_build_k2q_csr_with_schedule", &run_build_k2q_csr_with_schedule, + "q2k -> k2q CSR build with fused attention schedule metadata", + pybind11::arg("q2k"), + pybind11::arg("cu_q"), + pybind11::arg("cu_k"), + pybind11::arg("row_ptr"), + pybind11::arg("q_idx"), + pybind11::arg("scheduler_metadata"), + pybind11::arg("work_count"), + pybind11::arg("qsplit_idx"), + pybind11::arg("split_counts"), + pybind11::arg("topk"), + pybind11::arg("blk_kv"), + pybind11::arg("total_rows"), + pybind11::arg("max_kv_blocks"), + pybind11::arg("target_q_per_cta"), + pybind11::arg("work_capacity"), + pybind11::arg("max_seqlen_q")); +} diff --git a/python/fmha_sm12x/cute/src/sm12x/decode_schedule.py b/python/fmha_sm12x/cute/src/sm12x/decode_schedule.py new file mode 100644 index 0000000..cd78e34 --- /dev/null +++ b/python/fmha_sm12x/cute/src/sm12x/decode_schedule.py @@ -0,0 +1,151 @@ +# SPDX-License-Identifier: MIT + +"""Split-KV schedule for paged fp8 decode attention. + +The public PageKV representation remains this repo's rectangular page table: +``page_table [B, max_pages]`` plus ``seqused_k [B]``. The schedule only +describes how query tiles and KV chunks are split into work items. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Optional + +import torch + + +@dataclass +class DecodeAttentionSchedule: + split_kv: bool + cta_tile_q: int + num_q_tiles: int + kv_chunk_size_pages: int + kv_chunk_size_tokens: int + work_count: int + padded_work_count: int + partial_rows: int + max_split_count: int + max_grid_size: int + active_blocks_per_sm: int + num_sms: int + base_cta: int + request_indices: torch.Tensor + qo_tile_indices: torch.Tensor + kv_tile_indices: torch.Tensor + merge_indptr: torch.Tensor + o_indptr: torch.Tensor + block_valid_mask: torch.Tensor + kv_pages: torch.Tensor + split_counts: torch.Tensor + + +def _require_i32_cuda_1d(tensor: torch.Tensor, *, name: str) -> None: + if tensor.dtype != torch.int32: + raise TypeError(f"{name} must be torch.int32") + if tensor.ndim != 1: + raise ValueError(f"{name} must be rank-1") + if not tensor.is_cuda: + raise ValueError(f"{name} must be a CUDA tensor") + if not tensor.is_contiguous(): + raise ValueError(f"{name} must be contiguous") + + +def prepare_decode_schedule( + *, + seqused_k: torch.Tensor, + page_size: int, + seqlen_q: int, + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + max_seqlen_k: int, + enable_cuda_graph: bool = False, + max_grid_size: Optional[int] = None, + fixed_split_size: Optional[int] = None, + disable_split_kv: bool = False, +) -> DecodeAttentionSchedule: + """Build paged decode split-KV schedule on the GPU. + + A single CUDA kernel reads ``seqused_k`` on device and writes all + schedule index arrays. Only a small summary tensor is D2H-synced so + the wrapper can size O_partial / pick the kernel grid / choose the + split-vs-non-split compile path. + + ``max_seqlen_k`` is the host-side worst-case bound used to pad the + work-tile arrays. It must satisfy ``max(seqused_k) <= max_seqlen_k``. + """ + _require_i32_cuda_1d(seqused_k, name="seqused_k") + # The single-CTA scheduler keeps per-batch state in shared memory, so + # batch must not exceed 1024; larger batches need a multi-CTA design. + if int(seqused_k.shape[0]) > 1024: + raise NotImplementedError( + "decode schedule supports batch <= 1024 " + f"(got batch={int(seqused_k.shape[0])}); larger batches require " + "a multi-CTA scheduler." + ) + # The two device-data-dependent hang guards (seqused_k[b] >= seqlen_q and + # the last-partial-page column count) are enforced inside the raw + # build_decode_schedule() launch wrapper so that direct callers of the raw + # entrypoint are protected too; see _validate_decode_seqused_k there. + if int(page_size) <= 0: + raise ValueError("page_size must be positive") + if int(seqlen_q) <= 0: + raise ValueError("seqlen_q must be positive") + if int(num_qo_heads) <= 0 or int(num_kv_heads) <= 0: + raise ValueError("head counts must be positive") + if int(num_qo_heads) % int(num_kv_heads) != 0: + raise ValueError("num_qo_heads must be divisible by num_kv_heads") + if int(num_qo_heads) // int(num_kv_heads) != 16: + raise NotImplementedError("decode schedule currently supports only qhead_per_kv=16") + if int(head_dim) != 128: + raise NotImplementedError("decode schedule currently supports only head_dim=128") + if int(max_seqlen_k) <= 0: + raise ValueError("max_seqlen_k must be positive") + # max(seqused_k) <= max_seqlen_k is enforced in the raw build_decode_schedule + # wrapper (single source of truth; protects direct raw callers too). + + from .fwd_decode.build_decode_schedule import build_decode_schedule + + raw = build_decode_schedule( + seqused_k, + page_size=int(page_size), + seqlen_q=int(seqlen_q), + num_qo_heads=int(num_qo_heads), + num_kv_heads=int(num_kv_heads), + head_dim=int(head_dim), + max_seqlen_k=int(max_seqlen_k), + enable_cuda_graph=bool(enable_cuda_graph), + max_grid_size=0 if max_grid_size is None else int(max_grid_size), + fixed_split_size=-1 if fixed_split_size is None else int(fixed_split_size), + disable_split_kv=bool(disable_split_kv), + ) + return DecodeAttentionSchedule( + split_kv=bool(raw["split_kv"]), + cta_tile_q=int(raw["cta_tile_q"]), + num_q_tiles=int(raw["num_q_tiles"]), + kv_chunk_size_pages=int(raw["kv_chunk_size_pages"]), + kv_chunk_size_tokens=int(raw["kv_chunk_size_tokens"]), + work_count=int(raw["work_count"]), + padded_work_count=int(raw["padded_work_count"]), + partial_rows=int(raw["partial_rows"]), + max_split_count=int(raw["max_split_count"]), + max_grid_size=int(raw["max_grid_size"]), + active_blocks_per_sm=int(raw["active_blocks_per_sm"]), + num_sms=int(raw["num_sms"]), + base_cta=int(raw["base_cta"]), + request_indices=raw["request_indices"], + qo_tile_indices=raw["qo_tile_indices"], + kv_tile_indices=raw["kv_tile_indices"], + merge_indptr=raw["merge_indptr"], + o_indptr=raw["o_indptr"], + block_valid_mask=raw["block_valid_mask"], + kv_pages=raw["kv_pages"], + split_counts=raw["split_counts"], + ) + + +__all__ = [ + "DecodeAttentionSchedule", + "prepare_decode_schedule", +] diff --git a/python/fmha_sm12x/cute/src/sm12x/fwd_decode/__init__.py b/python/fmha_sm12x/cute/src/sm12x/fwd_decode/__init__.py new file mode 100644 index 0000000..4f456d8 --- /dev/null +++ b/python/fmha_sm12x/cute/src/sm12x/fwd_decode/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: MIT + +"""SM12x decode helper kernels.""" diff --git a/python/fmha_sm12x/cute/src/sm12x/fwd_decode/build_decode_schedule/__init__.py b/python/fmha_sm12x/cute/src/sm12x/fwd_decode/build_decode_schedule/__init__.py new file mode 100644 index 0000000..55e7f19 --- /dev/null +++ b/python/fmha_sm12x/cute/src/sm12x/fwd_decode/build_decode_schedule/__init__.py @@ -0,0 +1,171 @@ +# SPDX-License-Identifier: MIT + +"""JIT-loaded CUDA/C++ extension for SM12x paged decode split-KV scheduling.""" + +from __future__ import annotations + +import os + +import torch +from torch.utils.cpp_extension import load + +from minimax_msa.arch import ( + cpp_extension_arch_flag, + cuda_arch_cache_suffix, + require_sm12x_csrc_arch, +) + +_THIS_DIR = os.path.dirname(os.path.abspath(__file__)) +_SRC = os.path.join(_THIS_DIR, "build_decode_schedule.cu") + +_EXTRA_CFLAGS = ["-O3"] +_EXTRA_CUDA_CFLAGS_BASE = [ + "-O3", + "--use_fast_math", + "-lineinfo", + "--ptxas-options=-v", + "--expt-relaxed-constexpr", +] + +_ext = None + + +def _validate_decode_seqused_k( + seqused_k: torch.Tensor, *, seqlen_q: int, page_size: int, max_seqlen_k: int +) -> None: + """Reject seqused_k values that hang the kernel or overflow the pad. + + These are device-data-dependent guards enforced at the raw launch + boundary so that *every* caller — the high-level + ``prepare_decode_schedule`` wrapper and any direct user of this raw + entrypoint alike — is protected (the C++ kernel only TORCH_CHECKs + structural invariants and would otherwise spin on an all-masked row or + scatter past the worst-case-padded output arrays). + + (0) max(seqused_k) <= max_seqlen_k. The host sizes the work-tile arrays + from max_pages_global = ceil(max_seqlen_k / page_size); a longer + seqused_k produces work_count > pad_work and the wrapper's + narrow(0, 0, padded_work_count) / kernel scatter run out of bounds. + + (1) seqused_k[b] >= seqlen_q. The kernel's causal col_limit for the + first packed q-token is seqlen_k - seqlen_q + 1, which goes <= 0 + when seqlen_k < seqlen_q. That all-masked row hits a mask-codegen + path with PTX-undefined shift counts and the kernel hangs. It is + also a batched-decode invariant: seqlen_k must include the + seqlen_q new tokens being emitted. + + (2) seqused_k[b] % page_size in {0, 8, 16, ..., 120}. The same hang + fires when the last partial page has < q_tokens_per_group=8 valid + columns, because the last MMA tile then hits the all-masked row + case for the trailing q-tokens. + """ + + max_seqlen_k_i = int(max_seqlen_k) + max_used_k = int(seqused_k.max().item()) if seqused_k.numel() > 0 else 0 + if max_used_k > max_seqlen_k_i: + raise ValueError( + f"max_seqlen_k must cover max(seqused_k), got {max_seqlen_k_i} " + f"for max seqused_k {max_used_k}" + ) + seqlen_q_i = int(seqlen_q) + bad_q = seqused_k < seqlen_q_i + if bool(bad_q.any().item()): + bad_idx = int(torch.nonzero(bad_q, as_tuple=True)[0][0].item()) + bad_val = int(seqused_k[bad_idx].item()) + raise ValueError( + f"decode kernel requires seqused_k[b] >= seqlen_q (= {seqlen_q_i}) " + f"for every batch. Got seqused_k[{bad_idx}]={bad_val}. " + f"This is also a batched-decode invariant: seqlen_k must include " + f"the seqlen_q new tokens being emitted." + ) + page_size_i = int(page_size) + rem = seqused_k % page_size_i + bad_rem = (rem > 0) & (rem < seqlen_q_i) + if bool(bad_rem.any().item()): + bad_idx = int(torch.nonzero(bad_rem, as_tuple=True)[0][0].item()) + bad_val = int(seqused_k[bad_idx].item()) + raise ValueError( + f"decode kernel requires seqused_k[b] % page_size in " + f"{{0, {seqlen_q_i}, {seqlen_q_i*2}, ..., {max(page_size_i//seqlen_q_i, 1)*seqlen_q_i}}}. " + f"Got seqused_k[{bad_idx}]={bad_val}, last partial page has " + f"{bad_val % page_size_i} valid columns (< seqlen_q={seqlen_q_i}). " + f"Round seqused_k up to the next multiple of {seqlen_q_i} OR to " + f"a multiple of {page_size_i}." + ) + + +def _cccl_include_flags() -> list[str]: + cuda_home = os.environ.get("CUDA_HOME", "/usr/local/cuda") + cccl = os.path.join(cuda_home, "include", "cccl") + return [f"-I{cccl}"] if os.path.isdir(cccl) else [] + + +def _load_ext(): + global _ext + if _ext is None: + require_sm12x_csrc_arch("fmha_sm12x.decode_schedule") + _ext = load( + name=f"sparse_decode_schedule_sm12x_ext{cuda_arch_cache_suffix()}", + sources=[_SRC], + extra_cflags=_EXTRA_CFLAGS, + extra_cuda_cflags=[ + *_EXTRA_CUDA_CFLAGS_BASE, + cpp_extension_arch_flag(), + *_cccl_include_flags(), + ], + verbose=False, + ) + return _ext + + +def build_decode_schedule( + seqused_k: torch.Tensor, + *, + page_size: int, + seqlen_q: int, + num_qo_heads: int, + num_kv_heads: int, + head_dim: int, + max_seqlen_k: int, + enable_cuda_graph: bool = False, + max_grid_size: int = 0, + fixed_split_size: int = -1, + disable_split_kv: bool = False, +) -> dict[str, object]: + """Build paged decode schedule arrays on device with the SM12x helper.""" + + # Device-data-dependent hang guards enforced here (the lowest common + # launch boundary) so direct callers can't spin the kernel on an + # all-masked row. Skip when the config is non-positive so the C++ + # TORCH_CHECKs surface the clean structural error instead. + if int(seqlen_q) > 0 and int(page_size) > 0: + _validate_decode_seqused_k( + seqused_k, seqlen_q=int(seqlen_q), page_size=int(page_size), + max_seqlen_k=int(max_seqlen_k), + ) + + raw = _load_ext().build_decode_schedule( + seqused_k, + int(page_size), + int(seqlen_q), + int(num_qo_heads), + int(num_kv_heads), + int(head_dim), + int(max_seqlen_k), + bool(enable_cuda_graph), + int(max_grid_size), + int(fixed_split_size), + bool(disable_split_kv), + ) + pad = int(raw["padded_work_count"]) + for key in ( + "request_indices", + "qo_tile_indices", + "kv_tile_indices", + "block_valid_mask", + ): + raw[key] = raw[key].narrow(0, 0, pad) + return raw + + +__all__ = ["build_decode_schedule"] diff --git a/python/fmha_sm12x/cute/src/sm12x/fwd_decode/build_decode_schedule/build_decode_schedule.cu b/python/fmha_sm12x/cute/src/sm12x/fwd_decode/build_decode_schedule/build_decode_schedule.cu new file mode 100644 index 0000000..fdde024 --- /dev/null +++ b/python/fmha_sm12x/cute/src/sm12x/fwd_decode/build_decode_schedule/build_decode_schedule.cu @@ -0,0 +1,665 @@ +// SPDX-License-Identifier: MIT + +#include + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace cg = cooperative_groups; + +namespace py = pybind11; + +namespace { + +constexpr int64_t kTargetDecodeQHeadPerKv = 16; +constexpr int64_t kTargetDecodeHeadDim = 128; +constexpr int64_t kTargetDecodeTileM = 128; +constexpr int64_t kTargetDecodeKvBytes = 1; + +template +T ceil_div(T x, T y) { + return (x + y - 1) / y; +} + +// --------------------------------------------------------------------------- +// GPU schedule kernel: one CTA computes the full decode schedule in-place. +// +// Layout: single block of `tpb` threads (must be >= batch, power-of-two, +// <= 1024). Each thread b handles one batch slot. +// +// Pipeline: +// 1. Per-thread: read seqused_k[b] -> compute kv_pages[b]; write to output. +// 2. Block reduce: max_pages, min_pages, sum_pages. +// 3. Thread 0: decide kv_chunk_size_pages (binary search + variable-kv +// load-balance heuristic) and split_kv flag. +// 4. Per-thread: compute split_counts[b], inclusive prefix-sum of +// work-slots and partial-rows for offsets. +// 5. Block reduce: max_split_count. +// 6. Per-thread: write o_indptr, merge_indptr, request_indices, +// qo_tile_indices, kv_tile_indices, block_valid_mask for batch b's +// slot range. +// 7. Thread 0: write 5-i32 info_scalars = (split_kv, kv_chunk_size_pages, +// padded_work_count, max_split_count, partial_rows). +// +// `padded_work_count_pad` is the host-side worst-case pad for the output +// arrays; the kernel writes block_valid_mask=0 for slots past work_count. +// Caller must ensure tpb >= batch and is a power of two. +// --------------------------------------------------------------------------- +// Warp-shuffle helpers (single-warp fast path). +__device__ __forceinline__ int warp_reduce_max(int v) { + v = max(v, __shfl_xor_sync(0xFFFFFFFFu, v, 1)); + v = max(v, __shfl_xor_sync(0xFFFFFFFFu, v, 2)); + v = max(v, __shfl_xor_sync(0xFFFFFFFFu, v, 4)); + v = max(v, __shfl_xor_sync(0xFFFFFFFFu, v, 8)); + v = max(v, __shfl_xor_sync(0xFFFFFFFFu, v, 16)); + return v; +} +__device__ __forceinline__ int warp_reduce_min(int v) { + v = min(v, __shfl_xor_sync(0xFFFFFFFFu, v, 1)); + v = min(v, __shfl_xor_sync(0xFFFFFFFFu, v, 2)); + v = min(v, __shfl_xor_sync(0xFFFFFFFFu, v, 4)); + v = min(v, __shfl_xor_sync(0xFFFFFFFFu, v, 8)); + v = min(v, __shfl_xor_sync(0xFFFFFFFFu, v, 16)); + return v; +} +__device__ __forceinline__ int warp_reduce_sum(int v) { + v += __shfl_xor_sync(0xFFFFFFFFu, v, 1); + v += __shfl_xor_sync(0xFFFFFFFFu, v, 2); + v += __shfl_xor_sync(0xFFFFFFFFu, v, 4); + v += __shfl_xor_sync(0xFFFFFFFFu, v, 8); + v += __shfl_xor_sync(0xFFFFFFFFu, v, 16); + return v; +} +// Inclusive scan within a warp using shuffle. +__device__ __forceinline__ int warp_inclusive_scan_sum(int v) { + int n = __shfl_up_sync(0xFFFFFFFFu, v, 1); + if ((threadIdx.x & 31) >= 1) v += n; + n = __shfl_up_sync(0xFFFFFFFFu, v, 2); + if ((threadIdx.x & 31) >= 2) v += n; + n = __shfl_up_sync(0xFFFFFFFFu, v, 4); + if ((threadIdx.x & 31) >= 4) v += n; + n = __shfl_up_sync(0xFFFFFFFFu, v, 8); + if ((threadIdx.x & 31) >= 8) v += n; + n = __shfl_up_sync(0xFFFFFFFFu, v, 16); + if ((threadIdx.x & 31) >= 16) v += n; + return v; +} + +__global__ void build_decode_schedule_gpu_kernel( + const int32_t* __restrict__ seqused_k, + int batch, + int page_size, + int seqlen_q, + int num_q_tiles, + int num_kv_heads, + int q_tokens_per_group, + int max_grid_size, + int fixed_split_size, + int disable_split_kv, + int enable_cuda_graph, + int padded_work_count_pad, + // outputs: + int32_t* __restrict__ kv_pages, + int32_t* __restrict__ split_counts, + int32_t* __restrict__ request_indices, + int32_t* __restrict__ qo_tile_indices, + int32_t* __restrict__ kv_tile_indices, + int32_t* __restrict__ block_valid_mask, + int32_t* __restrict__ merge_indptr, + int32_t* __restrict__ o_indptr, + int32_t* __restrict__ info_scalars) { + // Multi-CTA cooperative kernel. CTA 0 (warp 0) does the small + // sequential decision phases (reductions, binary-search chunk pick, + // prefix scan, info_scalars). All CTAs then collaborate on the + // scatter phase via grid-stride loop. grid.sync() between phases. + cg::grid_group grid = cg::this_grid(); + constexpr int kMaxBatch = 1024; + constexpr int kMaxWarps = 32; // tpb<=1024 -> at most 32 warps + __shared__ int s_kv_pages[kMaxBatch]; + __shared__ int s_split_counts[kMaxBatch]; + __shared__ int s_work_slots[kMaxBatch]; // inclusive prefix-sum + __shared__ int s_partial_slots[kMaxBatch]; // inclusive prefix-sum + __shared__ int s_chunk_size; + __shared__ int s_split_kv_flag; + __shared__ int s_work_count_shared; + __shared__ int s_warp_max[kMaxWarps]; + __shared__ int s_warp_min[kMaxWarps]; + __shared__ int s_warp_sum[kMaxWarps]; + __shared__ int s_warp_max_split[kMaxWarps]; + + const int tid = threadIdx.x; + const int tpb = blockDim.x; + const int bid = blockIdx.x; + const int n_ctas = gridDim.x; + const int lane = tid & 31; + const int warp = tid >> 5; + + // Single-CTA design (4 warps = 128 threads). Decision in warp 0 via + // shuffles; scatter across all warps via grid-stride loop within the + // CTA. No grid.sync() needed. + { + // Per-thread: read kv_pages, write to gmem + shmem. + int kv_pages_b = 0; + if (tid < batch) { + int sk = seqused_k[tid]; + kv_pages_b = (sk + page_size - 1) / page_size; + if (kv_pages_b < 1) kv_pages_b = 1; + s_kv_pages[tid] = kv_pages_b; + kv_pages[tid] = kv_pages_b; + } else if (tid < kMaxBatch) { + s_kv_pages[tid] = 0; + } + __syncthreads(); + + // Chunk-size decision. + // + // First, cross-warp reduction of max/min/sum over ALL batches. Each + // thread brings its own batch slot's kv_pages (or sentinel for + // tid >= batch); each warp does a shuffle-based intra-warp reduce + // and writes per-warp partials to shmem; warp 0 then combines the + // warp partials with another shuffle reduce. The two-level reduce is + // required for batch > 32 (more than one warp of per-batch slots). + const int num_warps = tpb >> 5; + int chunk = 0; + int split_kv_flag = 0; + int max_pages = 0; + int min_pages = 0; + int sum_pages = 0; + { + const int active_my_slot = (tid < batch) ? 1 : 0; + int v_for_max = active_my_slot ? kv_pages_b : INT_MIN; + int v_for_min = active_my_slot ? kv_pages_b : INT_MAX; + int v_for_sum = active_my_slot ? kv_pages_b : 0; + int warp_max = warp_reduce_max(v_for_max); + int warp_min = warp_reduce_min(v_for_min); + int warp_sum = warp_reduce_sum(v_for_sum); + if (lane == 0) { + s_warp_max[warp] = warp_max; + s_warp_min[warp] = warp_min; + s_warp_sum[warp] = warp_sum; + } + } + __syncthreads(); + if (warp == 0) { + int lv_max = (lane < num_warps) ? s_warp_max[lane] : INT_MIN; + int lv_min = (lane < num_warps) ? s_warp_min[lane] : INT_MAX; + int lv_sum = (lane < num_warps) ? s_warp_sum[lane] : 0; + max_pages = warp_reduce_max(lv_max); + min_pages = warp_reduce_min(lv_min); + sum_pages = warp_reduce_sum(lv_sum); + } + if (warp == 0) { + // Helper: compute work_x = sum_b(ceil(kv_pages_b / chunk_size)) * num_q_tiles + // across ALL batches using lane-parallel iteration in groups of 32. + // Each iteration of the outer loop covers 32 batches; warp_reduce_sum + // gives the partial; accumulate into work_x in lockstep across lanes. + auto compute_work_x = [&](int chunk_size) -> int { + int sum = 0; + for (int b_base = 0; b_base < batch; b_base += 32) { + int b = b_base + lane; + int kvp = (b < batch) ? s_kv_pages[b] : 0; + int c_x = (b < batch) ? ((kvp + chunk_size - 1) / chunk_size) : 0; + sum += warp_reduce_sum(c_x); + } + return sum * num_q_tiles; + }; + + int base_work_count = batch * num_q_tiles; + int base_cta = base_work_count * num_kv_heads; + // Split-KV thresholds. Splitting a batch into KV chunks adds a combine + // pass (extra O_partial / LSE_partial fp32 writes, ~5-10us); below these + // sizes that overhead outweighs the added parallelism: + // kMinUsefulChunkPages = 16 chunk floor (~2K tokens); smaller chunks + // fall below the combine break-even. + // kTinyKvNoSplitPages = 8 max kv_pages <= 8 (<= 1K tokens) never + // splits - the whole KV fits inside one + // combine-overhead window. + constexpr int kMinUsefulChunkPages = 16; + constexpr int kTinyKvNoSplitPages = 8; + int min_chunk_pages_floor = (128 / page_size); + if (min_chunk_pages_floor < 1) min_chunk_pages_floor = 1; + const int min_chunk_pages = max(min_chunk_pages_floor, kMinUsefulChunkPages); + if (disable_split_kv != 0) { + chunk = max_pages; + split_kv_flag = 0; + } else if (fixed_split_size > 0) { + chunk = max(fixed_split_size, 1); + int work_x = compute_work_x(chunk); + split_kv_flag = (work_x != base_work_count) ? 1 : 0; + } else if (base_cta >= max_grid_size) { + chunk = max_pages; + split_kv_flag = 0; + } else if (max_pages <= kTinyKvNoSplitPages) { + // KV per batch is too short for split to pay back the combine + // overhead. Skip the binary search entirely. + chunk = max_pages; + split_kv_flag = 0; + } else { + int low = min(min_chunk_pages, max_pages); + int high = max_pages; + while (low < high) { + int mid = (low + high) >> 1; + int work_x = compute_work_x(mid); + if (work_x * num_kv_heads > max_grid_size) low = mid + 1; + else high = mid; + } + chunk = low; + // Variable-kv load-balance override. When kv-lengths span a wide + // range (one long batch among many short ones), the binary search + // picks a `chunk` that just fills the grid, leaving the long batch + // serial on one CTA while short batches finish and idle their SMs. + // For an imbalanced batch (max/avg >= 1.5) use a smaller chunk + // (~avg_pages/4 splits per average batch, still >= the 2K-token + // floor) so the long batch breaks into more parallel slots; the + // avg/4 target matches the 1-CTA/SM attn kernel. avg_pages >= 4 + // guards against tiny kv where splits would dominate. + int avg_pages = (sum_pages + batch - 1) / batch; + if (max_pages * 2 >= avg_pages * 3 && avg_pages >= 4) { + int balance_chunk = max(min_chunk_pages, avg_pages >> 2); + if (max_pages > balance_chunk && balance_chunk < chunk) { + chunk = balance_chunk; + } + } + // The combine kernel caps max_splits at 256 (bounded by its sLSE smem + // allocation and per-thread LSE-reduction registers). Round `chunk` + // up so no batch ever produces more than 256 splits; this only costs + // parallelism for ultra-long context (the chunk floor rises past + // kv ~ 512K). + constexpr int kCombineMaxSplits = 256; + int min_chunk_for_combine = (max_pages + kCombineMaxSplits - 1) / + kCombineMaxSplits; + if (chunk < min_chunk_for_combine) chunk = min_chunk_for_combine; + int work_x = compute_work_x(chunk); + split_kv_flag = (enable_cuda_graph != 0 || work_x != base_work_count) ? 1 : 0; + if (split_kv_flag == 0) chunk = max_pages; + } + if (lane == 0) { + s_chunk_size = chunk; + s_split_kv_flag = split_kv_flag; + } + } + __syncthreads(); + + // Compute split_counts + parallel inclusive scan. + int chunks_b = 0; + int work_slots_b = 0; + int partial_slots_b = 0; + if (tid < batch) { + chunks_b = s_split_kv_flag + ? ((s_kv_pages[tid] + s_chunk_size - 1) / s_chunk_size) + : 1; + work_slots_b = chunks_b * num_q_tiles; + partial_slots_b = chunks_b * num_q_tiles * q_tokens_per_group; + s_split_counts[tid] = chunks_b; + split_counts[tid] = chunks_b; + } else if (tid < kMaxBatch) { + s_split_counts[tid] = 0; + } + + // Inclusive scan via warp shuffle for batch <= 32. For batch > 32, + // fall back to Hillis-Steele in shared memory (rare in production). + if (batch <= 32) { + int inc_w = (tid < batch) ? work_slots_b : 0; + int inc_p = (tid < batch) ? partial_slots_b : 0; + if (warp == 0) { + inc_w = warp_inclusive_scan_sum(inc_w); + inc_p = warp_inclusive_scan_sum(inc_p); + } + if (tid < batch) { + s_work_slots[tid] = inc_w; + s_partial_slots[tid] = inc_p; + } + } else { + s_work_slots[tid] = (tid < batch) ? work_slots_b : 0; + s_partial_slots[tid] = (tid < batch) ? partial_slots_b : 0; + __syncthreads(); + for (int off = 1; off < tpb; off <<= 1) { + int w_add = (tid >= off) ? s_work_slots[tid - off] : 0; + int p_add = (tid >= off) ? s_partial_slots[tid - off] : 0; + __syncthreads(); + s_work_slots[tid] += w_add; + s_partial_slots[tid] += p_add; + __syncthreads(); + } + } + __syncthreads(); + + // max_split_count via cross-warp reduce (covers batch > 32, where a + // single-warp reduce would undercount). + int local_max_split = (tid < batch) ? chunks_b : INT_MIN; + { + int warp_max = warp_reduce_max(local_max_split); + if (lane == 0) s_warp_max_split[warp] = warp_max; + } + __syncthreads(); + int max_split_count_local = 0; + if (warp == 0) { + int v = (lane < num_warps) ? s_warp_max_split[lane] : INT_MIN; + max_split_count_local = warp_reduce_max(v); + if (max_split_count_local < 1) max_split_count_local = 1; + } + + // Write o_indptr and merge_indptr in parallel. + // o_indptr[0] = 0, o_indptr[b+1] = inclusive_partial[b] + // merge_indptr[0] = 0 + // merge_indptr[b * seqlen_q + q + 1] = + // exclusive_prefix_chunks[b] * seqlen_q + (q + 1) * chunks[b] + // exclusive_prefix_chunks[b] = (work_slots inclusive scan / num_q_tiles) - chunks[b] + // ... actually simpler: chunks-prefix == work_slots-prefix / num_q_tiles, since + // work_slots_b = chunks_b * num_q_tiles. When num_q_tiles==1, they're equal. + if (tid == 0) { + o_indptr[0] = 0; + merge_indptr[0] = 0; + } + if (tid < batch) { + o_indptr[tid + 1] = s_partial_slots[tid]; + // Compute exclusive prefix sum of chunks for THIS batch tid. + int incl_chunks = s_work_slots[tid] / max(num_q_tiles, 1); + int excl_chunks = incl_chunks - chunks_b; + // Parallel per-q write within this batch slot. + for (int q = 0; q < seqlen_q; ++q) { + merge_indptr[tid * seqlen_q + q + 1] = + excl_chunks * seqlen_q + (q + 1) * chunks_b; + } + } + + // Write info_scalars + s_work_count_shared so the scatter + // phase across other warps can read it. Thread (batch-1) holds the + // inclusive scan total; broadcast via shared mem. + if (tid == batch - 1) { + s_work_count_shared = s_work_slots[batch - 1]; + } else if (tid == 0 && batch == 0) { + s_work_count_shared = 0; + } + __syncthreads(); + int work_count = s_work_count_shared; + int partial_rows = (batch > 0) ? s_partial_slots[batch - 1] : 0; + if (warp == 0 && lane == 0) { + int padded_wc = (enable_cuda_graph != 0 && s_split_kv_flag != 0) + ? max(work_count, max(1, max_grid_size / num_kv_heads)) + : work_count; + info_scalars[0] = s_split_kv_flag; + info_scalars[1] = s_chunk_size; + info_scalars[2] = padded_wc; + info_scalars[3] = max_split_count_local; + info_scalars[4] = partial_rows; + } + } + + // Sync so warps 1-3 see the shared-memory state written by warp 0 in + // Part A (s_split_counts, s_work_slots, s_work_count_shared). + __syncthreads(); + + // Scatter (all 128 threads, intra-CTA grid-stride loop). + // Note: s_split_counts and s_work_slots are valid in shared mem from + // Part A. Use them directly (no global reload). + int work_count_total = s_work_count_shared; + for (int idx = tid; idx < work_count_total; idx += tpb) { + // Inverse map idx -> (b, q_tile, kv_tile) via linear search. + int b_found = 0; + int prev_prefix = 0; + for (int j = 0; j < batch; ++j) { + int p = s_split_counts[j] * num_q_tiles; + if (idx < prev_prefix + p) { b_found = j; break; } + prev_prefix += p; + } + int within = idx - prev_prefix; + int chunks_at_b = s_split_counts[b_found]; + int q_tile = (chunks_at_b > 0) ? (within / chunks_at_b) : 0; + int kv_tile = within - q_tile * chunks_at_b; + request_indices[idx] = b_found; + qo_tile_indices[idx] = q_tile; + kv_tile_indices[idx] = kv_tile; + block_valid_mask[idx] = 1; + } +} + +int64_t determine_cta_tile_q(int64_t packed_q_len, int64_t head_dim, int compute_major) { + if (packed_q_len > 64 && head_dim < 256) { + return 128; + } + if (compute_major >= 8) { + return packed_q_len > 16 ? 64 : 16; + } + return 64; +} + +// Decode attn kernel runs at 1 CTA/SM (UTCMMA + warp specialization +// holds ~240 reg/thread x 512 threads, saturating the register file). +// max_grid_size = num_sms is therefore the exact attainable grid; we +// don't probe occupancy because the CUTE DSL kernel's function pointer +// isn't reachable from C++ (would have required a proxy kernel, whose +// register pressure differs from the real one and gave misleading 8-16 +// blocks/SM estimates that triggered over-splitting at small kv). +std::tuple estimate_decode_grid_size( + int64_t /*num_qo_heads*/, + int64_t /*num_kv_heads*/, + int64_t /*head_dim*/, + int64_t max_grid_size_override) { + int dev_id = 0; + AT_CUDA_CHECK(cudaGetDevice(&dev_id)); + int num_sms = 0; + AT_CUDA_CHECK(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id)); + if (max_grid_size_override > 0) { + int64_t active_blocks = std::max( + 1, ceil_div(max_grid_size_override, std::max(num_sms, 1))); + return {max_grid_size_override, active_blocks, num_sms}; + } + // Hardcoded: 1 CTA/SM for the decode attn kernel. + return {static_cast(num_sms), int64_t{1}, num_sms}; +} + +int64_t split_work_x( + const std::vector& kv_pages, + int64_t chunk_pages, + int64_t num_q_tiles, + bool split_kv) { + int64_t work = 0; + for (int32_t pages : kv_pages) { + const int64_t chunks = split_kv ? ceil_div(std::max(pages, 1), chunk_pages) : 1; + work += chunks * num_q_tiles; + } + return work; +} + +} // namespace + +// ============================================================================ +// GPU-only schedule launcher. All schedule arrays are computed in a single +// CUDA kernel from seqused_k on GPU - no D2H copy of seqused_k, no H2D copy +// of index arrays. Only the info_scalars summary tensor is read back to +// host (single small D2H sync) so the wrapper can size O_partial, launch +// the right kernel grid, and choose the split/non-split compile path. +// ============================================================================ +py::dict build_decode_schedule( + torch::Tensor seqused_k, + int64_t page_size, + int64_t seqlen_q, + int64_t num_qo_heads, + int64_t num_kv_heads, + int64_t head_dim, + int64_t max_seqlen_k, + bool enable_cuda_graph, + int64_t max_grid_size_override, + int64_t fixed_split_size, + bool disable_split_kv) { + TORCH_CHECK(seqused_k.is_cuda(), "seqused_k must be a CUDA tensor"); + TORCH_CHECK(seqused_k.scalar_type() == at::kInt, "seqused_k must be int32"); + TORCH_CHECK(seqused_k.dim() == 1, "seqused_k must have shape [B]"); + TORCH_CHECK(seqused_k.is_contiguous(), "seqused_k must be contiguous"); + TORCH_CHECK(page_size > 0, "page_size must be positive"); + TORCH_CHECK(seqlen_q > 0, "seqlen_q must be positive"); + TORCH_CHECK(num_qo_heads > 0 && num_kv_heads > 0, "head counts must be positive"); + TORCH_CHECK(num_qo_heads % num_kv_heads == 0, + "num_qo_heads must be divisible by num_kv_heads"); + TORCH_CHECK(num_qo_heads / num_kv_heads == kTargetDecodeQHeadPerKv, + "decode schedule currently supports only qhead_per_kv=16"); + TORCH_CHECK(head_dim == kTargetDecodeHeadDim, + "decode schedule currently supports only head_dim=128"); + TORCH_CHECK(max_seqlen_k > 0, "max_seqlen_k must be positive"); + + const int64_t batch = seqused_k.size(0); + TORCH_CHECK(batch > 0, "seqused_k must contain at least one batch item"); + // The single-CTA scheduler keeps per-batch state in shared memory + // (s_kv_pages[1024], etc.) with one thread per batch and tpb <= 1024, so + // batch must not exceed 1024; larger batches need a multi-CTA design. + TORCH_CHECK(batch <= 1024, + "build_decode_schedule supports batch <= 1024 (got ", + batch, + "); larger batches require a multi-CTA scheduler."); + + // Host-side derived constants (no D2H needed for these). + int dev_id = 0; + AT_CUDA_CHECK(cudaGetDevice(&dev_id)); + int compute_major = 0; + AT_CUDA_CHECK(cudaDeviceGetAttribute(&compute_major, + cudaDevAttrComputeCapabilityMajor, + dev_id)); + const int64_t qhead_per_kv = num_qo_heads / num_kv_heads; + const int64_t packed_q_len = seqlen_q * qhead_per_kv; + const int64_t cta_tile_q = determine_cta_tile_q(packed_q_len, head_dim, compute_major); + const int64_t num_q_tiles = ceil_div(packed_q_len, cta_tile_q); + TORCH_CHECK(kTargetDecodeTileM % qhead_per_kv == 0, + "decode tile_m must be divisible by qhead_per_kv"); + const int64_t q_tokens_per_group = kTargetDecodeTileM / qhead_per_kv; + const auto [max_grid_size, active_blocks_per_sm, num_sms] = + estimate_decode_grid_size(num_qo_heads, num_kv_heads, head_dim, + max_grid_size_override); + + // Worst-case padding for the work-tile arrays. When the heuristic picks + // the smallest possible chunk (min_chunk_pages = max(128/page_size, 1)), + // a single batch can produce up to max_pages_global / min_chunk_pages + // chunks; across batches this is bounded by sum-of-pages which is at most + // batch x max_pages_global. + const int64_t max_pages_global = ceil_div(max_seqlen_k, page_size); + int64_t pad_work = batch * num_q_tiles * std::max(max_pages_global, 1); + // CUDA-graph capture pads the work-tile count up to the fixed grid the + // graph was captured with. The kernel computes + // padded_wc = max(work_count, max(1, max_grid_size / num_kv_heads)) + // whenever (enable_cuda_graph && split_kv). The Python wrapper then does + // request_indices.narrow(0, 0, padded_work_count) and the tile-scheduler + // reads that many entries, so the output index arrays must be allocated to + // at least this padded count. For small max_seqlen_k the page-based bound + // (batch * num_q_tiles * max_pages_global) can fall below the graph pad, so + // take the max here to avoid an out-of-bounds narrow / scheduler read. + if (enable_cuda_graph) { + const int64_t graph_pad = + std::max(1, max_grid_size / std::max(num_kv_heads, 1)); + pad_work = std::max(pad_work, graph_pad); + } + const int64_t pad_partial = pad_work * q_tokens_per_group; + + const auto device = seqused_k.device(); + auto i32_options = torch::TensorOptions().dtype(torch::kInt32).device(device); + + // Allocate all output arrays on GPU. + auto kv_pages_tensor = torch::empty({batch}, i32_options); + auto split_counts_tensor = torch::empty({batch}, i32_options); + auto request_tensor = torch::empty({pad_work}, i32_options); + auto qo_tile_tensor = torch::empty({pad_work}, i32_options); + auto kv_tile_tensor = torch::empty({pad_work}, i32_options); + auto mask_tensor = torch::empty({pad_work}, i32_options); + auto merge_indptr_tensor = torch::empty({batch * seqlen_q + 1}, i32_options); + auto o_indptr_tensor = torch::empty({batch + 1}, i32_options); + auto info_tensor = torch::empty({5}, i32_options); + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(seqused_k.get_device()); + AT_CUDA_CHECK(cudaMemsetAsync( + request_tensor.data_ptr(), 0, pad_work * sizeof(int32_t), stream)); + AT_CUDA_CHECK(cudaMemsetAsync( + qo_tile_tensor.data_ptr(), 0, pad_work * sizeof(int32_t), stream)); + AT_CUDA_CHECK(cudaMemsetAsync( + kv_tile_tensor.data_ptr(), 0, pad_work * sizeof(int32_t), stream)); + AT_CUDA_CHECK(cudaMemsetAsync( + mask_tensor.data_ptr(), 0, pad_work * sizeof(int32_t), stream)); + + // tpb: threads per CTA. Use 128 (4 warps) so we have plenty of warps + // for the per-CTA setup phase and for scatter work. CTA 0's warp 0 + // does the decision phase; all warps in all CTAs collaborate on the + // scatter phase via grid-stride loop. + int tpb = 128; + if (batch > 128) { + // Cap to next power of two so single-CTA reductions still work. + tpb = 1; + while (tpb < static_cast(batch)) tpb <<= 1; + if (tpb > 1024) tpb = 1024; + } + + // Single CTA - all decisions and scatter happen on one CTA with 4 warps. + // No grid.sync(), no cooperative-launch overhead. + build_decode_schedule_gpu_kernel<<<1, tpb, 0, stream>>>( + seqused_k.data_ptr(), + static_cast(batch), + static_cast(page_size), + static_cast(seqlen_q), + static_cast(num_q_tiles), + static_cast(num_kv_heads), + static_cast(q_tokens_per_group), + static_cast(max_grid_size), + static_cast(fixed_split_size), + static_cast(disable_split_kv), + static_cast(enable_cuda_graph), + static_cast(pad_work), + kv_pages_tensor.data_ptr(), + split_counts_tensor.data_ptr(), + request_tensor.data_ptr(), + qo_tile_tensor.data_ptr(), + kv_tile_tensor.data_ptr(), + mask_tensor.data_ptr(), + merge_indptr_tensor.data_ptr(), + o_indptr_tensor.data_ptr(), + info_tensor.data_ptr()); + AT_CUDA_CHECK(cudaGetLastError()); + + // Single D2H sync for the 5 summary scalars. Payload = 20 bytes. + auto info_cpu = info_tensor.cpu(); + const int32_t* info_host = info_cpu.data_ptr(); + const int32_t split_kv_flag = info_host[0]; + const int32_t kv_chunk_size_pages = info_host[1]; + const int32_t padded_work_count = info_host[2]; + const int32_t max_split_count = info_host[3]; + const int32_t partial_rows = info_host[4]; + + const int64_t base_work_count = batch * num_q_tiles; + const int64_t base_cta = base_work_count * num_kv_heads; + const int64_t work_count = + (split_kv_flag != 0) ? static_cast(padded_work_count) : base_work_count; + + py::dict result; + result["split_kv"] = (split_kv_flag != 0); + result["cta_tile_q"] = cta_tile_q; + result["num_q_tiles"] = num_q_tiles; + result["kv_chunk_size_pages"] = static_cast(kv_chunk_size_pages); + result["kv_chunk_size_tokens"] = static_cast(kv_chunk_size_pages) * page_size; + result["work_count"] = work_count; + result["padded_work_count"] = static_cast(padded_work_count); + result["partial_rows"] = static_cast(partial_rows); + result["max_split_count"] = static_cast(max_split_count); + result["max_grid_size"] = max_grid_size; + result["active_blocks_per_sm"] = active_blocks_per_sm; + result["num_sms"] = num_sms; + result["base_cta"] = base_cta; + result["request_indices"] = request_tensor; + result["qo_tile_indices"] = qo_tile_tensor; + result["kv_tile_indices"] = kv_tile_tensor; + result["block_valid_mask"] = mask_tensor; + result["split_counts"] = split_counts_tensor; + result["kv_pages"] = kv_pages_tensor; + result["merge_indptr"] = merge_indptr_tensor; + result["o_indptr"] = o_indptr_tensor; + return result; +} + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { + m.def("build_decode_schedule", &build_decode_schedule, + "Build paged decode split-KV schedule on GPU"); +} diff --git a/python/fmha_sm12x/cute/src/sm12x/prepare_k2q_csr.py b/python/fmha_sm12x/cute/src/sm12x/prepare_k2q_csr.py new file mode 100644 index 0000000..36c34a9 --- /dev/null +++ b/python/fmha_sm12x/cute/src/sm12x/prepare_k2q_csr.py @@ -0,0 +1,221 @@ +# SPDX-License-Identifier: MIT + +"""Sparse k2q CSR builder for SM120/SM121.""" + +from __future__ import annotations + +from typing import Optional + +import torch + +from ._schedule import SparseAttentionSchedule, SPARSE_SCHEDULE_MODEL + +_SUPPORTED_TOPK = (4, 8, 16, 32) +_SUPPORTED_BLK_KV = 128 + + +def _ceil_div(x: int, y: int) -> int: + return (x + y - 1) // y + + +def _validate_cu_seqlens( + cu_seqlens: torch.Tensor, + *, + name: str, + expected_total: int | None = None, +) -> int: + if int(cu_seqlens.shape[0]) < 1: + raise ValueError(f"{name} must have shape [B + 1]") + if int(cu_seqlens[0].item()) != 0: + raise ValueError(f"{name}[0] must be 0") + diffs = cu_seqlens[1:] - cu_seqlens[:-1] + if bool((diffs < 0).any().item()): + raise ValueError(f"{name} must be monotonically non-decreasing") + total = int(cu_seqlens[-1].item()) + if expected_total is not None and total != int(expected_total): + raise ValueError( + f"{name}[-1] must match expected total {int(expected_total)}, got {total}" + ) + return total + + +def _row_bounds_from_cu_seqlens(cu_seqlens_k: torch.Tensor, blk_kv: int) -> tuple[int, int, int]: + lengths = cu_seqlens_k[1:] - cu_seqlens_k[:-1] + if int(lengths.numel()) == 0: + return 0, 0, 0 + rows = torch.div(lengths + int(blk_kv) - 1, int(blk_kv), rounding_mode="floor") + return int(rows.sum().item()), int(rows.max().item()), int(lengths.max().item()) + + +class SparseK2qCsrBuilderSm12x: + """Build the k2q CSR reverse index with the SM12x CUDA helper pipeline.""" + + def __init__(self) -> None: + self._run = None + self._run_with_schedule = None + + def _ensure_loaded(self) -> None: + if self._run is None: + from .build_k2q_csr import ( + run_build_k2q_csr, + run_build_k2q_csr_with_schedule, + ) + + self._run = run_build_k2q_csr + self._run_with_schedule = run_build_k2q_csr_with_schedule + + def __call__( + self, + q2k_indices: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + *, + total_k: int, + blk_kv: int = 128, + max_seqlen_k: Optional[int] = None, + max_seqlen_q: Optional[int] = None, + total_rows: Optional[int] = None, + qhead_per_kv: int = 1, + return_schedule: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor] | tuple[torch.Tensor, torch.Tensor, SparseAttentionSchedule]: + if blk_kv != _SUPPORTED_BLK_KV: + raise ValueError( + f"SparseK2qCsrBuilderSm12x only supports blk_kv == {_SUPPORTED_BLK_KV}, " + f"got {blk_kv}" + ) + if q2k_indices.dtype != torch.int32: + raise TypeError(f"q2k_indices must be torch.int32, got {q2k_indices.dtype}") + if q2k_indices.ndim != 3: + raise ValueError( + "q2k_indices must be rank-3 [head_kv, total_q, topK], " + f"got shape {tuple(q2k_indices.shape)}" + ) + if not q2k_indices.is_contiguous(): + raise ValueError("q2k_indices must be contiguous") + if cu_seqlens_q.dtype != torch.int32 or cu_seqlens_k.dtype != torch.int32: + raise TypeError("cu_seqlens_q and cu_seqlens_k must be torch.int32") + if cu_seqlens_q.ndim != 1 or cu_seqlens_k.ndim != 1: + raise ValueError("cu_seqlens_q and cu_seqlens_k must be rank-1") + if cu_seqlens_q.shape != cu_seqlens_k.shape: + raise ValueError("cu_seqlens_q and cu_seqlens_k must share shape [B + 1]") + if not (q2k_indices.is_cuda and cu_seqlens_q.is_cuda and cu_seqlens_k.is_cuda): + raise ValueError("all inputs must be CUDA tensors") + if q2k_indices.device != cu_seqlens_q.device or q2k_indices.device != cu_seqlens_k.device: + raise ValueError("all inputs must share a device") + if not cu_seqlens_q.is_contiguous() or not cu_seqlens_k.is_contiguous(): + raise ValueError("cu_seqlens_q and cu_seqlens_k must be contiguous") + + total_k = int(total_k) + if total_k < 0: + raise ValueError(f"total_k must be non-negative, got {total_k}") + + head_kv, total_q, topk = (int(v) for v in q2k_indices.shape) + if topk not in _SUPPORTED_TOPK: + raise ValueError(f"SparseK2qCsrBuilderSm12x only supports topK in {_SUPPORTED_TOPK}, got {topk}") + + _validate_cu_seqlens(cu_seqlens_q, name="cu_seqlens_q", expected_total=total_q) + total_k_from_cu = _validate_cu_seqlens( + cu_seqlens_k, + name="cu_seqlens_k", + expected_total=total_k, + ) + total_rows_from_cu, max_kv_blocks, max_k_tokens_from_cu = _row_bounds_from_cu_seqlens( + cu_seqlens_k, + blk_kv, + ) + if return_schedule and max_seqlen_k is None: + raise ValueError("build_k2q_csr requires max_seqlen_k when return_schedule=True") + if max_seqlen_k is not None and int(max_seqlen_k) < max_k_tokens_from_cu: + raise ValueError("max_seqlen_k must cover every cu_seqlens_k segment") + if max_seqlen_k is not None: + max_kv_blocks = max(max_kv_blocks, _ceil_div(int(max_seqlen_k), blk_kv)) + if total_rows is not None and int(total_rows) != total_rows_from_cu: + raise ValueError( + "total_rows must match rows implied by cu_seqlens_k, " + f"got {int(total_rows)} vs {total_rows_from_cu}" + ) + total_rows = total_rows_from_cu + total_k = total_k_from_cu + nnz_upper_bound = total_q * topk + qhead_per_kv = int(qhead_per_kv) + if qhead_per_kv <= 0: + raise ValueError(f"qhead_per_kv must be positive, got {qhead_per_kv}") + if return_schedule: + if max_seqlen_q is None: + raise ValueError("build_k2q_csr requires max_seqlen_q when return_schedule=True") + max_seqlen_q = int(max_seqlen_q) + + device = q2k_indices.device + k2q_row_ptr = torch.empty((head_kv, total_rows + 1), dtype=torch.int32, device=device) + k2q_q_indices = torch.empty((head_kv, nnz_upper_bound), dtype=torch.int32, device=device) + schedule = None + if return_schedule: + target_q_per_cta = SPARSE_SCHEDULE_MODEL.balanced_target_q_per_cta( + total_q=total_q, + topk=topk, + blk_kv=blk_kv, + head_kv=head_kv, + qhead_per_kv=qhead_per_kv, + device=device, + ) + scheduler_metadata_capacity = SPARSE_SCHEDULE_MODEL.flat_schedule_capacity( + total_rows=total_rows, + total_q=total_q, + topk=topk, + head_kv=head_kv, + target_q_per_cta=target_q_per_cta, + ) + schedule = SparseAttentionSchedule( + enabled=True, + scheduler_metadata=torch.empty((scheduler_metadata_capacity, 6), dtype=torch.int32, device=device), + work_count=torch.empty((1,), dtype=torch.int32, device=device), + qsplit_indices=torch.empty_like(k2q_q_indices), + split_counts=torch.empty((total_q, head_kv), dtype=torch.int32, device=device), + target_q_per_cta=target_q_per_cta, + ) + + if total_rows == 0 or total_q == 0 or head_kv == 0 or topk == 0: + k2q_row_ptr.zero_() + k2q_q_indices.fill_(-1) + if schedule is not None: + schedule.work_count.zero_() + schedule.split_counts.zero_() + return k2q_row_ptr, k2q_q_indices, schedule + return k2q_row_ptr, k2q_q_indices + + self._ensure_loaded() + with torch.cuda.nvtx.range("SparseK2qCsrSm12x_Pipeline"): + if schedule is None: + self._run( + q2k_indices, + cu_seqlens_q, + cu_seqlens_k, + k2q_row_ptr, + k2q_q_indices, + topk, + blk_kv, + total_rows, + max_kv_blocks, + ) + else: + self._run_with_schedule( + q2k_indices, + cu_seqlens_q, + cu_seqlens_k, + k2q_row_ptr, + k2q_q_indices, + schedule.scheduler_metadata, + schedule.work_count, + schedule.qsplit_indices, + schedule.split_counts, + topk, + blk_kv, + total_rows, + max_kv_blocks, + schedule.target_q_per_cta, + schedule.work_capacity, + max_seqlen_q, + ) + if schedule is not None: + return k2q_row_ptr, k2q_q_indices, schedule + return k2q_row_ptr, k2q_q_indices diff --git a/python/fmha_sm12x/sparse.py b/python/fmha_sm12x/sparse.py new file mode 100644 index 0000000..e15e65e --- /dev/null +++ b/python/fmha_sm12x/sparse.py @@ -0,0 +1,334 @@ +# SPDX-License-Identifier: MIT + +"""SM12x sparse-attention public surface.""" + +from __future__ import annotations + +import math + +import torch + +from ._decode import SparseDecodePagedAttentionWrapper, sparse_decode_atten_func +from ._fp4 import fp4_indexer_block_scores +from ._nvfp4 import ( + Nvfp4QuantizedTensor, + dequantize_nvfp4_128x4_to_bf16, + nvfp4_global_scale_from_amax, + nvfp4_scale_128x4_offset, + quantize_bf16_to_nvfp4_128x4, + quantize_kv_bf16_to_nvfp4_128x4, + sparse_atten_nvfp4_kv_func, + swizzle_nvfp4_scale_to_128x4, +) +from ._lse import run_lse +from .api import fmha_sm12x, fmha_sm12x_plan, sparse_topk_select + +__all__ = [ + "SparseK2qCsrBuilderSm12x", + "build_k2q_csr", + "sparse_atten_func", + "sparse_atten_nvfp4_kv_func", + "sparse_decode_atten_func", + "SparseDecodePagedAttentionWrapper", + "fp4_indexer_block_scores", + "Nvfp4QuantizedTensor", + "quantize_bf16_to_nvfp4_128x4", + "quantize_kv_bf16_to_nvfp4_128x4", + "dequantize_nvfp4_128x4_to_bf16", + "swizzle_nvfp4_scale_to_128x4", + "nvfp4_global_scale_from_amax", + "nvfp4_scale_128x4_offset", + "sparse_topk_select", +] + + +def _rows_per_batch(cu_seqlens_k: torch.Tensor, block_size: int) -> list[int]: + vals = cu_seqlens_k.to("cpu", dtype=torch.int64, non_blocking=False).tolist() + return [(int(vals[i + 1]) - int(vals[i]) + block_size - 1) // block_size for i in range(len(vals) - 1)] + + +def _build_packed_row_map(rows: list[int]) -> tuple[list[list[int]], int]: + max_rows = max(rows, default=0) + row_map = [[-1 for _ in range(max_rows)] for _ in rows] + row_linear = 0 + for block in range(max_rows): + for batch, row_count in enumerate(rows): + if block < row_count: + row_map[batch][block] = row_linear + row_linear += 1 + return row_map, row_linear + + +_OPTIMIZED_K2Q_TOPK = (4, 8, 16, 32) +_OPTIMIZED_K2Q_BLOCK_SIZE = 128 + + +def _can_use_optimized_k2q( + q2k_indices: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + blk_kv: int, + kwargs: dict, +) -> bool: + if "total_k" not in kwargs or int(blk_kv) != _OPTIMIZED_K2Q_BLOCK_SIZE: + return False + if q2k_indices.dtype != torch.int32 or q2k_indices.ndim != 3: + return False + if int(q2k_indices.shape[2]) not in _OPTIMIZED_K2Q_TOPK: + return False + return bool(q2k_indices.is_cuda and cu_seqlens_q.is_cuda and cu_seqlens_k.is_cuda) + + +def _validate_cu_seqlens_pair( + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, +) -> None: + if cu_seqlens_q.dtype != torch.int32: + raise TypeError(f"cu_seqlens_q must be torch.int32, got {cu_seqlens_q.dtype}") + if cu_seqlens_k.dtype != torch.int32: + raise TypeError(f"cu_seqlens_k must be torch.int32, got {cu_seqlens_k.dtype}") + if cu_seqlens_q.ndim != 1: + raise ValueError("cu_seqlens_q must be rank-1") + if cu_seqlens_k.ndim != 1: + raise ValueError("cu_seqlens_k must be rank-1") + if cu_seqlens_q.shape != cu_seqlens_k.shape: + raise ValueError("cu_seqlens_q and cu_seqlens_k must share shape [B + 1]") + + +def _compact_page_table(page_table: torch.Tensor | None, cu_seqlens_k: torch.Tensor, block_size: int) -> torch.Tensor | None: + if page_table is None: + return None + pages: list[torch.Tensor] = [] + rows = _rows_per_batch(cu_seqlens_k, int(block_size)) + for batch, row_count in enumerate(rows): + if row_count > 0: + pages.append(page_table[batch, :row_count]) + if not pages: + return torch.empty((0,), dtype=torch.int32, device=page_table.device) + return torch.cat(pages).contiguous() + + +def build_k2q_csr( + q2k_indices: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + kv_block_size: int, + *, + total_k: int | None = None, + return_schedule: bool = False, + **_kwargs, +) -> tuple[torch.Tensor, torch.Tensor]: + """Torch reference q2k -> k2q CSR builder for SM12x validation. + + This pure-Torch reference returns only ``(k2q_row_ptr, k2q_q_indices)``. + The fused sparse-attention schedule (``return_schedule=True`` in the SM100 + surface) requires the optimized CUDA builder, so route through + ``SparseK2qCsrBuilderSm12x`` for that path rather than failing with a + silent 2-tuple that a 3-tuple caller would mis-unpack. + """ + + _ = total_k + if return_schedule: + raise ValueError( + "return_schedule=True is not supported by the Torch reference " + "build_k2q_csr; use SparseK2qCsrBuilderSm12x for the fused schedule." + ) + if q2k_indices.dtype != torch.int32 or q2k_indices.ndim != 3: + raise ValueError("q2k_indices must be int32 with shape [Hkv, total_q, topK]") + _validate_cu_seqlens_pair(cu_seqlens_q, cu_seqlens_k) + head_kv, total_q, topk = (int(v) for v in q2k_indices.shape) + rows = _rows_per_batch(cu_seqlens_k, int(kv_block_size)) + row_map, total_rows = _build_packed_row_map(rows) + row_ptr = torch.zeros((head_kv, total_rows + 1), dtype=torch.int32, device=q2k_indices.device) + q_indices = torch.full((head_kv, total_q * topk), -1, dtype=torch.int32, device=q2k_indices.device) + q_starts = cu_seqlens_q.to("cpu", dtype=torch.int64, non_blocking=False).tolist() + buckets: list[list[list[int]]] = [[[] for _ in range(total_rows)] for _ in range(head_kv)] + q2k_cpu = q2k_indices.to("cpu", non_blocking=False) + for batch, row_count in enumerate(rows): + for local_q in range(int(q_starts[batch + 1]) - int(q_starts[batch])): + q_global = int(q_starts[batch]) + local_q + for head in range(head_kv): + for item in q2k_cpu[head, q_global].tolist(): + block = int(item) + if block >= 0: + if block >= row_count: + raise ValueError(f"q2k_indices block index out of range for batch {batch}") + buckets[head][row_map[batch][block]].append(local_q) + for head in range(head_kv): + cursor = 0 + for row, entries in enumerate(buckets[head]): + row_ptr[head, row] = cursor + for value in sorted(entries): + q_indices[head, cursor] = int(value) + cursor += 1 + row_ptr[head, total_rows] = cursor + return row_ptr, q_indices + + +class SparseK2qCsrBuilderSm12x: + """CSR builder with an optimized SM12x CUDA path and reference fallback.""" + + def __init__(self, *, use_optimized: bool = True) -> None: + self._use_optimized = bool(use_optimized) + self._optimized = None + + def _optimized_builder(self): + if self._optimized is None: + from .cute.src.sm12x.prepare_k2q_csr import ( + SparseK2qCsrBuilderSm12x as _OptimizedSparseK2qCsrBuilderSm12x, + ) + + self._optimized = _OptimizedSparseK2qCsrBuilderSm12x() + return self._optimized + + def __call__( + self, + q2k_indices: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + **kwargs, + ): + _validate_cu_seqlens_pair(cu_seqlens_q, cu_seqlens_k) + blk_kv = int(kwargs.get("blk_kv", _OPTIMIZED_K2Q_BLOCK_SIZE)) + if self._use_optimized and _can_use_optimized_k2q( + q2k_indices, + cu_seqlens_q, + cu_seqlens_k, + blk_kv, + kwargs, + ): + return self._optimized_builder()( + q2k_indices, + cu_seqlens_q, + cu_seqlens_k, + **kwargs, + ) + if kwargs.get("return_schedule", False): + raise ValueError( + "return_schedule=True requires the optimized CUDA SM12x CSR builder" + ) + return build_k2q_csr(q2k_indices, cu_seqlens_q, cu_seqlens_k, blk_kv, **kwargs) + + +def _q2k_from_csr(k2q_row_ptr: torch.Tensor, k2q_q_indices: torch.Tensor, cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, topk: int, blk_kv: int) -> torch.Tensor: + head_kv = int(k2q_row_ptr.shape[0]) + total_q = int(cu_seqlens_q[-1].item()) + q2k = torch.full((head_kv, total_q, int(topk)), -1, dtype=torch.int32, device=k2q_row_ptr.device) + rows = _rows_per_batch(cu_seqlens_k, int(blk_kv)) + row_map, total_rows = _build_packed_row_map(rows) + if int(k2q_row_ptr.shape[1]) != total_rows + 1: + raise ValueError("k2q_row_ptr row count does not match cu_seqlens_k") + q_starts = cu_seqlens_q.to("cpu", dtype=torch.int64, non_blocking=False).tolist() + for batch, row_count in enumerate(rows): + qo_len = int(q_starts[batch + 1]) - int(q_starts[batch]) + for block in range(row_count): + row = row_map[batch][block] + for head in range(head_kv): + begin = int(k2q_row_ptr[head, row].item()) + end = int(k2q_row_ptr[head, row + 1].item()) + if begin < 0 or end < begin or end > int(k2q_q_indices.shape[1]): + raise ValueError("k2q row pointers are out of range") + for local_q_value in k2q_q_indices[head, begin:end].to("cpu", non_blocking=False).tolist(): + local_q = int(local_q_value) + if local_q < 0 or local_q >= qo_len: + raise ValueError(f"k2q query index out of range for batch {batch} (q index)") + q_global = int(q_starts[batch]) + local_q + slot = int((q2k[head, q_global] >= 0).sum().item()) + if slot < topk: + q2k[head, q_global, slot] = block + return q2k + + +_FP8_E4M3 = torch.float8_e4m3fn + + +def _stage_attention_dtypes( + q: torch.Tensor, k: torch.Tensor, v: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Accept SM100's forward dtype combinations, staging FP8 to BF16. + + Like the SM100 path, q/k/v may all share a dtype (BF16/FP16/FP8 E4M3) or be + a BF16 query with an FP8 E4M3 K/V cache. FP8 operands are dequantized to + BF16 (the SM100 FP8 path stages QK/PV in BF16), so the downstream Triton / + Torch reference runs in BF16 and matches the dequantized-BF16 reference. + """ + + same = q.dtype == k.dtype == v.dtype + fp8_kv_bf16_q = ( + q.dtype == torch.bfloat16 and k.dtype == _FP8_E4M3 and v.dtype == _FP8_E4M3 + ) + if not same and not fp8_kv_bf16_q: + raise TypeError( + "q, k, v must share a dtype, except a bf16 query with fp8_e4m3 K/V; " + f"got q={q.dtype}, k={k.dtype}, v={v.dtype}" + ) + + def _deq(t: torch.Tensor) -> torch.Tensor: + return t.to(torch.bfloat16) if t.dtype == _FP8_E4M3 else t + + return _deq(q), _deq(k), _deq(v) + + +def sparse_atten_func(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, k2q_row_ptr: torch.Tensor, k2q_q_indices: torch.Tensor, topK: int, *, cu_seqlens_q: torch.Tensor, cu_seqlens_k: torch.Tensor, max_seqlen_q: int, max_seqlen_k: int, blk_kv: int = 128, causal: bool = False, softmax_scale: float | None = None, lse_temperature_scale: float = 1.0, return_temperature_lse: bool = False, partial_dtype: torch.dtype = torch.bfloat16, return_softmax_lse: bool = False, page_table: torch.Tensor | None = None, seqused_k: torch.Tensor | None = None, schedule: object | None = None, usable_SM_count: int = -1, qk_dtype: torch.dtype | None = None, pv_dtype: torch.dtype | None = None, **_kwargs): + """Block-sparse varlen attention for SM12x. + + Uses the fused Triton kernel (:mod:`fmha_sm12x._triton_sparse`) for dense + KV when Triton is importable, and the Torch reference otherwise (paged KV, + or no Triton). Mirrors the SM100 ``sparse_atten_func`` surface: + ``schedule``, ``usable_SM_count``, ``partial_dtype``, ``qk_dtype`` and + ``pv_dtype`` are accepted for API compatibility but do not affect this + forward. ``return_temperature_lse`` returns a third LSE computed with + logits scaled by ``softmax_scale / lse_temperature_scale``. + """ + + _ = (max_seqlen_q, max_seqlen_k, seqused_k, schedule, usable_SM_count, partial_dtype, qk_dtype, pv_dtype) + lse_temperature_scale = float(lse_temperature_scale) + if not math.isfinite(lse_temperature_scale) or lse_temperature_scale <= 0.0: + raise ValueError( + f"lse_temperature_scale must be finite and > 0, got {lse_temperature_scale}" + ) + if bool(return_temperature_lse) and not bool(return_softmax_lse): + raise ValueError("return_temperature_lse=True requires return_softmax_lse=True") + q, k, v = _stage_attention_dtypes(q, k, v) + q2k = _q2k_from_csr(k2q_row_ptr, k2q_q_indices, cu_seqlens_q, cu_seqlens_k, int(topK), int(blk_kv)) + kv_heads = int(k.shape[1]) if k.ndim == 4 else int(k.shape[-2]) + block_indexes = q2k.permute(1, 0, 2).contiguous() + resolved_scale = float(softmax_scale) if softmax_scale is not None else 1.0 / math.sqrt(int(q.shape[-1])) + + if page_table is None and int(block_indexes.shape[1]) == kv_heads: + from ._triton_sparse import triton_dense_supported, triton_sparse_atten_dense + + if triton_dense_supported(q, k, v, int(blk_kv)): + out, lse, temperature_lse = triton_sparse_atten_dense( + q, k, v, block_indexes, + cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, + num_kv_heads=kv_heads, page_size=int(blk_kv), causal=bool(causal), + sm_scale=resolved_scale, return_lse=bool(return_softmax_lse), + lse_temperature_scale=lse_temperature_scale, + return_temperature_lse=bool(return_temperature_lse), + out_dtype=torch.bfloat16, + ) + if not return_softmax_lse: + return out + if bool(return_temperature_lse): + return out, lse, temperature_lse + return out, lse + + qo_lens = cu_seqlens_q[1:] - cu_seqlens_q[:-1] + kv_lens = cu_seqlens_k[1:] - cu_seqlens_k[:-1] + plan = fmha_sm12x_plan(qo_lens, kv_lens, int(q.shape[1]), kv_heads, page_size=int(blk_kv), causal=bool(causal)) + kv_indices = _compact_page_table(page_table, cu_seqlens_k, int(blk_kv)) + out, _ = fmha_sm12x(q, k, v, plan, kv_indices=kv_indices, kv_block_indexes=block_indexes, sm_scale=softmax_scale) + if not return_softmax_lse: + return out + lse = run_lse( + q, k, v, plan, kv_indices=kv_indices, kv_block_indexes=block_indexes, + sm_scale=resolved_scale, + ) + if bool(return_temperature_lse): + temperature_lse = run_lse( + q, k, v, plan, kv_indices=kv_indices, kv_block_indexes=block_indexes, + sm_scale=resolved_scale / lse_temperature_scale, + ) + return out, lse, temperature_lse + return out, lse diff --git a/python/minimax_msa/__init__.py b/python/minimax_msa/__init__.py new file mode 100644 index 0000000..8978f14 --- /dev/null +++ b/python/minimax_msa/__init__.py @@ -0,0 +1,3 @@ +# SPDX-License-Identifier: MIT + +"""Shared helpers for MiniMax Sparse Attention packages.""" diff --git a/python/minimax_msa/arch.py b/python/minimax_msa/arch.py new file mode 100644 index 0000000..4840684 --- /dev/null +++ b/python/minimax_msa/arch.py @@ -0,0 +1,184 @@ +# SPDX-License-Identifier: MIT + +"""CUDA architecture selection for MiniMax Sparse Attention JIT builds.""" + +from __future__ import annotations + +import os +import re +import shlex +from dataclasses import dataclass +from typing import Final + +_CUDA_ARCH_ENV: Final = "MSA_CUDA_ARCH" +_LEGACY_CUDA_ARCH_ENV: Final = "FMHA_SM100_CUDA_ARCH" +_NVCC_GENCODES_ENV: Final = "MSA_NVCC_GENCODES" +_LEGACY_NVCC_GENCODES_ENV: Final = "FMHA_SM100_NVCC_GENCODES" +_ARCH_RE: Final = re.compile(r"^(?:sm_|compute_)?(?P\d{2,3})(?Pa?)$") +_SM100_GENCODE_CODE_MARKERS: Final = ( + "code=sm_100a", + "code=sm_103a", + "code=compute_100a", + "code=compute_103a", +) +_SUPPORTED_SM100_TARGETS: Final = ("sm_100a", "sm_103a") +_SM12X_GENCODE_CODE_MARKERS: Final = ( + "code=sm_120", + "code=sm_121", + "code=compute_120", + "code=compute_121", +) +_SUPPORTED_SM12X_TARGETS: Final = ("sm_120", "sm_121") + + +class UnsupportedCudaArchError(RuntimeError): + """Raised when an architecture-specific kernel is requested for the wrong SM.""" + + +@dataclass(frozen=True, slots=True) +class CudaArch: + """CUDA virtual and real architecture pair.""" + + compute: str + code: str + + @property + def cache_suffix(self) -> str: + return self.code.replace("sm_", "_sm") + + @property + def arch_flag(self) -> str: + return f"-arch={self.code}" + + @property + def gencode_flag(self) -> str: + return f"-gencode=arch={self.compute},code={self.code}" + + +_DEFAULT_GENCODES: Final = ( + CudaArch(compute="compute_100a", code="sm_100a"), + CudaArch(compute="compute_103a", code="sm_103a"), +) + + +def _first_env(*names: str) -> str | None: + for name in names: + value = os.environ.get(name) + if value: + return value + return None + + +def _parse_cuda_arch(value: str) -> CudaArch: + normalized = value.strip().lower() + match = _ARCH_RE.fullmatch(normalized) + if match is None: + raise ValueError( + f"{_CUDA_ARCH_ENV} must look like sm_100a, sm_103a, sm_120, or sm_121; " + f"got {value!r}" + ) + digits = match.group("major") + suffix = match.group("suffix") + return CudaArch(compute=f"compute_{digits}{suffix}", code=f"sm_{digits}{suffix}") + + +def _detect_device_arch() -> CudaArch | None: + try: + import torch + except ImportError: + return None + + if not torch.cuda.is_available(): + return None + major, minor = torch.cuda.get_device_capability() + if major == 10 and minor == 0: + return CudaArch(compute="compute_100a", code="sm_100a") + if major == 10 and minor == 3: + return CudaArch(compute="compute_103a", code="sm_103a") + if major == 12 and minor in (0, 1): + digits = f"{major}{minor}" + return CudaArch(compute=f"compute_{digits}", code=f"sm_{digits}") + return None + + +def selected_cuda_arch() -> CudaArch | None: + """Return the explicit or detected single-arch CUDA target.""" + + explicit = _first_env(_CUDA_ARCH_ENV, _LEGACY_CUDA_ARCH_ENV) + if explicit: + return _parse_cuda_arch(explicit) + return _detect_device_arch() + + +def nvcc_gencode_flags() -> list[str]: + """Return gencode flags for csrc JIT builds.""" + + explicit_gencodes = _first_env(_NVCC_GENCODES_ENV, _LEGACY_NVCC_GENCODES_ENV) + if explicit_gencodes: + return shlex.split(explicit_gencodes) + arch = selected_cuda_arch() + if arch is not None: + return [arch.gencode_flag] + return [arch.gencode_flag for arch in _DEFAULT_GENCODES] + + + +def require_sm100_csrc_arch(component: str) -> None: + """Reject non-SM100-family targets for tcgen05/TMEM kernels.""" + + gencodes = nvcc_gencode_flags() + if all( + any(marker in flag for marker in _SM100_GENCODE_CODE_MARKERS) + for flag in gencodes + ): + return + selected = " ".join(gencodes) + supported = ", ".join(_SUPPORTED_SM100_TARGETS) + raise UnsupportedCudaArchError( + f"{component} is an SM100/SM103-only tcgen05/TMEM kernel and does not " + f"support CUDA target {selected}; supported targets: {supported}. " + "Add separate SM12x kernels instead of aliasing fmha_sm100 on SM120/SM121." + ) + + +def require_sm12x_csrc_arch(component: str) -> None: + """Reject non-SM12x-family targets for SM120/SM121 kernels.""" + + gencodes = nvcc_gencode_flags() + if all( + any(marker in flag for marker in _SM12X_GENCODE_CODE_MARKERS) + for flag in gencodes + ): + return + selected = " ".join(gencodes) + supported = ", ".join(_SUPPORTED_SM12X_TARGETS) + raise UnsupportedCudaArchError( + f"{component} is an SM120/SM121 kernel and does not support CUDA " + f"target {selected}; supported targets: {supported}." + ) + + +def ensure_sm100_kernel_arch(package: str) -> None: + """Compatibility alias for SM100-only kernel guards.""" + + require_sm100_csrc_arch(package) + + +def cpp_extension_arch_flag() -> str: + """Return the torch cpp_extension CUDA arch flag.""" + + arch = selected_cuda_arch() + if arch is not None: + return arch.arch_flag + return CudaArch(compute="compute_100", code="sm_100").arch_flag + + +def cuda_arch_cache_suffix() -> str: + """Return a stable cache suffix for non-default architecture selections.""" + + explicit_gencodes = _first_env(_NVCC_GENCODES_ENV, _LEGACY_NVCC_GENCODES_ENV) + if explicit_gencodes: + digest = re.sub(r"[^0-9A-Za-z]+", "_", explicit_gencodes).strip("_") + return f"_{digest}" if digest else "_custom_arch" + arch = selected_cuda_arch() + return arch.cache_suffix if arch is not None else "" diff --git a/tests/conftest.py b/tests/conftest.py index 08eed56..70cb568 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,76 @@ # SPDX-FileCopyrightText: Copyright (c) 2026 MiniMax # SPDX-License-Identifier: MIT +"""Pytest configuration for architecture-specific kernel tests.""" + +from __future__ import annotations + import sys from pathlib import Path +import pytest + sys.path.insert(0, str(Path(__file__).resolve().parent.parent / "python")) + +# Tests exercising the SM100 tcgen05/TMEM kernels; skipped off SM100/SM103. +_SM100_ONLY_TESTS = ( + "tests/integration/", + "tests/regression/", +) +# Tests exercising the SM120/SM121 fmha_sm12x package; skipped on SM100/SM103. +_SM12X_ONLY_TESTS = ( + "tests/test_sm12x_reference.py", + "tests/test_sm12x_triton_sparse.py", + "tests/test_sm12x_equivalence.py", +) + + +def _cuda_capability() -> tuple[int, int] | None: + try: + import torch + except ImportError: + return None + if not torch.cuda.is_available(): + return None + major, minor = torch.cuda.get_device_capability() + return int(major), int(minor) + + +def _is_sm100_family(capability: tuple[int, int] | None) -> bool: + return capability in ((10, 0), (10, 3)) + + +def _skipped_prefixes(capability: tuple[int, int] | None) -> tuple[str, ...]: + """Tests for the *other* architecture family are skipped. + + SM100/SM103 runs the SM100 suite and skips the SM12x suite; every other + device (including SM120/SM121) does the reverse. Arch-agnostic tests + (e.g. tests/test_arch.py, the arch-adaptive smoke tests) are in neither + list and always run. + """ + + if _is_sm100_family(capability): + return _SM12X_ONLY_TESTS + return _SM100_ONLY_TESTS + + +def pytest_ignore_collect(collection_path: Path, config: pytest.Config) -> bool: + _ = config + path = collection_path.as_posix() + return any(part in path for part in _skipped_prefixes(_cuda_capability())) + + +def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None: + _ = config + capability = _cuda_capability() + skipped = _skipped_prefixes(capability) + reason = ( + "SM12x-only test; requires SM120/SM121" + if _is_sm100_family(capability) + else "SM100-only tcgen05/TMEM kernel; requires SM100/SM103" + ) + skip_marker = pytest.mark.skip(reason=reason) + for item in items: + path = item.path.as_posix() + if any(part in path for part in skipped): + item.add_marker(skip_marker) diff --git a/tests/smoke/test_sparse_topk_forced.py b/tests/smoke/test_sparse_topk_forced.py index e9c9541..9c2fabb 100644 --- a/tests/smoke/test_sparse_topk_forced.py +++ b/tests/smoke/test_sparse_topk_forced.py @@ -12,7 +12,13 @@ sys.path.insert(0, str(Path(__file__).resolve().parents[2] / "python")) import torch -from fmha_sm100 import sparse_topk_select + +# sparse_topk_select is the same kernel in both packages; use the one built for +# the running arch (fmha_sm100 targets SM100/SM103, fmha_sm12x targets SM120/121). +if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 12: + from fmha_sm12x import sparse_topk_select +else: + from fmha_sm100 import sparse_topk_select def test_forced_blocks( @@ -79,7 +85,7 @@ def test_forced_zero_is_noop( force_begin_blocks=0, force_end_blocks=0, ) assert torch.equal(result_noop, result_zero), ( - f"[FAIL] force_begin=0, force_end=0 differs from default" + "[FAIL] force_begin=0, force_end=0 differs from default" ) print(" [PASS] force_begin=0, force_end=0 == no-force (bitwise identical)") @@ -178,8 +184,8 @@ def test_forced_large_k( if __name__ == "__main__": dev = torch.device("cuda") p = torch.cuda.get_device_properties(dev) - if not (p.major == 10 and p.minor in (0, 3)): - print("SKIP: SM100/SM103 GPU not available") + if (p.major, p.minor) not in ((10, 0), (10, 3), (12, 0), (12, 1)): + print("SKIP: Blackwell SM100/SM103/SM12x GPU not available") sys.exit(0) print("=== Testing forced block selection ===") diff --git a/tests/test_arch.py b/tests/test_arch.py new file mode 100644 index 0000000..1cfd0ce --- /dev/null +++ b/tests/test_arch.py @@ -0,0 +1,150 @@ +# SPDX-License-Identifier: MIT + +"""Tests for runtime CUDA architecture flag selection.""" + +import sys +from pathlib import Path + +import pytest + +sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "python")) + +from minimax_msa import arch + + +def test_default_arch_flags_when_no_cuda_is_detected(monkeypatch: pytest.MonkeyPatch) -> None: + """Default builds preserve the original SM100/SM103 csrc targets.""" + + monkeypatch.delenv("MSA_CUDA_ARCH", raising=False) + monkeypatch.delenv("FMHA_SM100_CUDA_ARCH", raising=False) + monkeypatch.delenv("MSA_NVCC_GENCODES", raising=False) + monkeypatch.delenv("FMHA_SM100_NVCC_GENCODES", raising=False) + monkeypatch.setattr(arch, "_detect_device_arch", lambda: None) + + assert arch.nvcc_gencode_flags() == [ + "-gencode=arch=compute_100a,code=sm_100a", + "-gencode=arch=compute_103a,code=sm_103a", + ] + assert arch.cpp_extension_arch_flag() == "-arch=sm_100" + assert arch.cuda_arch_cache_suffix() == "" + + +def test_explicit_sm121_arch_selects_single_target(monkeypatch: pytest.MonkeyPatch) -> None: + """MSA_CUDA_ARCH=sm_121 selects SM121 flags and cache names.""" + + monkeypatch.setenv("MSA_CUDA_ARCH", "sm_121") + monkeypatch.delenv("FMHA_SM100_CUDA_ARCH", raising=False) + monkeypatch.delenv("MSA_NVCC_GENCODES", raising=False) + monkeypatch.delenv("FMHA_SM100_NVCC_GENCODES", raising=False) + + assert arch.nvcc_gencode_flags() == ["-gencode=arch=compute_121,code=sm_121"] + assert arch.cpp_extension_arch_flag() == "-arch=sm_121" + assert arch.cuda_arch_cache_suffix() == "_sm121" + + + +def test_sm12x_topk_loader_targets_sm121( + monkeypatch: pytest.MonkeyPatch, tmp_path: Path +) -> None: + """The SM12x sparse_topk loader compiles the shared kernel for SM121. + + fmha_sm100 is untouched (its csrc still targets SM100/SM103); the SM12x + arch routing lives entirely in fmha_sm12x._topk. + """ + + monkeypatch.setenv("MSA_CUDA_ARCH", "sm_121") + monkeypatch.delenv("FMHA_SM100_CUDA_ARCH", raising=False) + monkeypatch.delenv("MSA_NVCC_GENCODES", raising=False) + monkeypatch.delenv("FMHA_SM100_NVCC_GENCODES", raising=False) + + from fmha_sm12x import _topk + + flags = _topk._nvcc_flags(tmp_path, tmp_path, tmp_path) + assert "-gencode=arch=compute_121,code=sm_121" in flags + assert "-gencode=arch=compute_100a,code=sm_100a" not in flags + + +def test_fmha_sm100_jit_is_arch_agnostic_source() -> None: + """fmha_sm100 carries no SM12x/arch-routing coupling (zero-diff PR goal).""" + + jit_src = ( + Path(__file__).resolve().parents[1] / "python/fmha_sm100/jit.py" + ).read_text() + assert "minimax_msa" not in jit_src + assert "-gencode=arch=compute_100a,code=sm_100a" in jit_src + + +def test_sm12x_csrc_guard_accepts_sm121(monkeypatch: pytest.MonkeyPatch) -> None: + """SM12x helper kernels accept SM121 targets.""" + + monkeypatch.setenv("MSA_CUDA_ARCH", "sm_121") + monkeypatch.delenv("FMHA_SM100_CUDA_ARCH", raising=False) + monkeypatch.delenv("MSA_NVCC_GENCODES", raising=False) + monkeypatch.delenv("FMHA_SM100_NVCC_GENCODES", raising=False) + + arch.require_sm12x_csrc_arch("fmha_sm12x.k2q_csr") + + +def test_sm12x_csrc_guard_rejects_default_sm100(monkeypatch: pytest.MonkeyPatch) -> None: + """SM12x helper kernels reject default SM100-only gencodes.""" + + monkeypatch.delenv("MSA_CUDA_ARCH", raising=False) + monkeypatch.delenv("FMHA_SM100_CUDA_ARCH", raising=False) + monkeypatch.delenv("MSA_NVCC_GENCODES", raising=False) + monkeypatch.delenv("FMHA_SM100_NVCC_GENCODES", raising=False) + monkeypatch.setattr(arch, "_detect_device_arch", lambda: None) + + with pytest.raises(arch.UnsupportedCudaArchError, match="SM120/SM121"): + arch.require_sm12x_csrc_arch("fmha_sm12x.k2q_csr") + + +def test_explicit_gencodes_override_single_arch(monkeypatch: pytest.MonkeyPatch) -> None: + """MSA_NVCC_GENCODES overrides the generated gencode list.""" + + gencodes = "-gencode=arch=compute_120,code=sm_120 -gencode=arch=compute_121,code=sm_121" + monkeypatch.setenv("MSA_CUDA_ARCH", "sm_100a") + monkeypatch.setenv("MSA_NVCC_GENCODES", gencodes) + monkeypatch.delenv("FMHA_SM100_CUDA_ARCH", raising=False) + monkeypatch.delenv("FMHA_SM100_NVCC_GENCODES", raising=False) + + assert arch.nvcc_gencode_flags() == [ + "-gencode=arch=compute_120,code=sm_120", + "-gencode=arch=compute_121,code=sm_121", + ] + assert "compute_120" in arch.cuda_arch_cache_suffix() + + +def test_invalid_explicit_arch_is_rejected(monkeypatch: pytest.MonkeyPatch) -> None: + """Invalid arch names fail before invoking nvcc.""" + + monkeypatch.setenv("MSA_CUDA_ARCH", "blackwell") + monkeypatch.delenv("FMHA_SM100_CUDA_ARCH", raising=False) + monkeypatch.delenv("MSA_NVCC_GENCODES", raising=False) + monkeypatch.delenv("FMHA_SM100_NVCC_GENCODES", raising=False) + + with pytest.raises(ValueError, match="MSA_CUDA_ARCH"): + arch.nvcc_gencode_flags() + + +def test_sm12x_facade_resolves_parallel_kernel_routes() -> None: + """The SM12x namespace resolves SM12x routes without aliasing SM100 sparse.""" + + import fmha_sm12x + + assert "fmha_sm12x" in fmha_sm12x.__all__ + assert "sparse_topk_select" in fmha_sm12x.__all__ + assert "Nvfp4QuantizedTensor" in fmha_sm12x.__all__ + assert callable(fmha_sm12x.dequantize_nvfp4_128x4_to_bf16) + assert callable(fmha_sm12x.nvfp4_scale_128x4_offset) + assert callable(fmha_sm12x.fp4_indexer_block_scores) + assert callable(fmha_sm12x.sparse_atten_nvfp4_kv_func) + assert callable(fmha_sm12x.sparse_decode_atten_func) + assert callable(fmha_sm12x.SparseDecodePagedAttentionWrapper) + + +def test_sm12x_arch_facade_imports_without_heavy_cuda_deps() -> None: + """The SM12x namespace exposes the shared arch helper as a light import.""" + + import fmha_sm12x.arch as sm12x_arch + + assert sm12x_arch.nvcc_gencode_flags is arch.nvcc_gencode_flags diff --git a/tests/test_sm12x_equivalence.py b/tests/test_sm12x_equivalence.py new file mode 100644 index 0000000..4d55463 --- /dev/null +++ b/tests/test_sm12x_equivalence.py @@ -0,0 +1,432 @@ +# SPDX-License-Identifier: MIT + +"""SM12x equivalents of the SM100 behavioural tests. + +Each test mirrors a behaviour the SM100 suite checks (proxy-KV pipeline, +onlyscore output, q-offset override, paged sparse attention, the FP4 indexer, +and paged decode) against an independent Torch oracle, using only the +``fmha_sm12x`` public surface. FP8/tcgen05-specific SM100 tests have no SM12x +analog and are intentionally absent. +""" + +from __future__ import annotations + +import math + +import pytest +import torch + + +def _need_cuda() -> None: + if not torch.cuda.is_available(): + pytest.skip("CUDA is required") + + +def _sparse_ref_dense(q, k, v, block_ids, cu_q, cu_k, *, page_size, sm_scale, causal): + """Independent oracle: per-(token, kv_head) block-sparse attention. + + ``block_ids`` is ``[total_q, Hkv, topk]`` (batch-local block ids, -1 pad), + shared across each GQA group. Returns float32 ``[total_q, Hq, D]``. + """ + + total_q, num_qo_heads, head_dim = q.shape + num_kv_heads = k.shape[1] + h_ratio = num_qo_heads // num_kv_heads + out = torch.zeros((total_q, num_qo_heads, head_dim), dtype=torch.float32, device=q.device) + cuq = cu_q.tolist() + cuk = cu_k.tolist() + for b in range(len(cuq) - 1): + q0, q1, k0, k1 = cuq[b], cuq[b + 1], cuk[b], cuk[b + 1] + kv_len, qo_len = k1 - k0, q1 - q0 + for lq in range(qo_len): + qi = q0 + lq + vis = (kv_len - qo_len) + lq + for h in range(num_qo_heads): + kvh = h // h_ratio + pos = [] + for bid in block_ids[qi, kvh].tolist(): + if bid < 0 or bid * page_size >= kv_len: + continue + for p in range(bid * page_size, min(bid * page_size + page_size, kv_len)): + if (not causal) or p <= vis: + pos.append(p) + if not pos: + continue + idx = torch.tensor(pos, device=q.device, dtype=torch.long) + k_sel = k[k0:k1][idx, kvh].float() + v_sel = v[k0:k1][idx, kvh].float() + logits = (k_sel @ q[qi, h].float()) * sm_scale + out[qi, h] = torch.softmax(logits, dim=0) @ v_sel + return out + + +def _cos(a, b): + return torch.nn.functional.cosine_similarity( + a.float().reshape(-1), b.float().reshape(-1), dim=0 + ).item() + + +def test_proxy_kv_e2e_pipeline_matches_reference(): + # max_score (dense proxy) -> sparse_topk_select -> CSR -> sparse_atten_func, + # all via fmha_sm12x, vs a Torch oracle over the kernel-selected blocks. + _need_cuda() + from fmha_sm12x import ( + build_k2q_csr, + fmha_sm12x, + fmha_sm12x_plan, + sparse_atten_func, + sparse_topk_select, + ) + + torch.manual_seed(0) + dev = torch.device("cuda") + total_q, num_kv_heads, h_ratio, head_dim, page_size, topk = 4, 2, 2, 128, 128, 16 + num_qo_heads = num_kv_heads * h_ratio + n_pages = 20 + kv_len = n_pages * page_size + cu_q = torch.tensor([0, total_q], dtype=torch.int32, device=dev) + cu_k = torch.tensor([0, kv_len], dtype=torch.int32, device=dev) + + # Proxy cache (num_qo_heads_proxy == num_kv_heads_real, MQA-compressed). + proxy_q = torch.randn(total_q, num_kv_heads, head_dim, device=dev, dtype=torch.bfloat16) * 0.3 + proxy_k = torch.randn(kv_len, 1, head_dim, device=dev, dtype=torch.bfloat16) * 0.3 + proxy_v = torch.randn_like(proxy_k) + sm_scale = 1.0 / math.sqrt(head_dim) + + proxy_plan = fmha_sm12x_plan( + cu_q[1:] - cu_q[:-1], cu_k[1:] - cu_k[:-1], num_kv_heads, 1, + page_size=page_size, output_maxscore=True, causal=True, + ) + _, max_score = fmha_sm12x( + proxy_q, proxy_k, proxy_v, proxy_plan, sm_scale=sm_scale, + output_o=False, output_maxscore=True, + ) + assert max_score is not None and max_score.shape[0] == num_kv_heads + + block_ids = sparse_topk_select(max_score.contiguous(), topk, num_valid_pages=n_pages) + assert block_ids.shape == (total_q, num_kv_heads, topk) + + # Real (GQA) cache, attended sparsely with the selected blocks. + real_k = torch.randn(kv_len, num_kv_heads, head_dim, device=dev, dtype=torch.bfloat16) * 0.3 + real_v = torch.randn_like(real_k) + real_q = torch.randn(total_q, num_qo_heads, head_dim, device=dev, dtype=torch.bfloat16) * 0.3 + q2k = block_ids.permute(1, 0, 2).contiguous() + row_ptr, q_idx = build_k2q_csr(q2k, cu_q, cu_k, page_size) + out = sparse_atten_func( + real_q, real_k, real_v, row_ptr, q_idx, topk, + cu_seqlens_q=cu_q, cu_seqlens_k=cu_k, max_seqlen_q=total_q, max_seqlen_k=kv_len, + blk_kv=page_size, causal=True, softmax_scale=sm_scale, + ) + assert out.shape == (total_q, num_qo_heads, head_dim) and not out.isnan().any() + + ref = _sparse_ref_dense( + real_q, real_k, real_v, block_ids, cu_q, cu_k, + page_size=page_size, sm_scale=sm_scale, causal=True, + ) + assert _cos(out, ref) > 0.999 + + +def test_onlyscore_matches_full_run_and_skips_output(): + # output_maxscore path: score is independent of output_o, O is skipped when + # output_o=False, and the score equals a Torch per-tile max-logit reference. + _need_cuda() + from fmha_sm12x import fmha_sm12x, fmha_sm12x_plan + + torch.manual_seed(1) + dev = torch.device("cuda") + total_q, num_heads, head_dim, page_size = 3, 2, 128, 128 + kv_len = 3 * page_size + cu_q = torch.tensor([0, total_q], dtype=torch.int32, device=dev) + cu_k = torch.tensor([0, kv_len], dtype=torch.int32, device=dev) + q = torch.randn(total_q, num_heads, head_dim, device=dev, dtype=torch.bfloat16) * 0.3 + k = torch.randn(kv_len, num_heads, head_dim, device=dev, dtype=torch.bfloat16) * 0.3 + v = torch.randn_like(k) + sm_scale = 1.0 / math.sqrt(head_dim) + + plan = fmha_sm12x_plan( + cu_q[1:] - cu_q[:-1], cu_k[1:] - cu_k[:-1], num_heads, num_heads, + page_size=page_size, output_maxscore=True, causal=True, + ) + o_full, score_full = fmha_sm12x(q, k, v, plan, sm_scale=sm_scale, output_o=True, output_maxscore=True) + o_none, score_only = fmha_sm12x(q, k, v, plan, sm_scale=sm_scale, output_o=False, output_maxscore=True) + + assert o_full is not None and o_none is None + torch.testing.assert_close(score_only, score_full, atol=0, rtol=0) + + # Torch reference: per 128-token tile, max causal logit. + n_tiles = kv_len // page_size + ref = torch.full((num_heads, n_tiles, total_q), float("-inf"), device=dev, dtype=torch.float32) + for qi in range(total_q): + vis = (kv_len - total_q) + qi + for h in range(num_heads): + logits = (k[:, h].float() @ q[qi, h].float()) * sm_scale + for t in range(n_tiles): + seg = logits[t * page_size : (t + 1) * page_size] + mask = torch.arange(t * page_size, (t + 1) * page_size, device=dev) <= vis + if mask.any(): + ref[h, t, qi] = seg[mask].max() + finite = torch.isfinite(ref) + torch.testing.assert_close(score_full[finite], ref[finite], atol=2e-2, rtol=2e-2) + + +@pytest.mark.parametrize("override", ["int", "tensor", "none"]) +def test_q_offset_override_matches_explicit_plan(override): + # fmha_sm12x(q_offset_override=X) must equal a plan built with qo_offset=X. + _need_cuda() + from fmha_sm12x import fmha_sm12x, fmha_sm12x_plan + + torch.manual_seed(2) + dev = torch.device("cuda") + total_q, num_heads, head_dim = 4, 2, 16 + kv_len = 12 + cu_q = torch.tensor([0, total_q], dtype=torch.int32, device=dev) + cu_k = torch.tensor([0, kv_len], dtype=torch.int32, device=dev) + q = torch.randn(total_q, num_heads, head_dim, device=dev, dtype=torch.bfloat16) + k = torch.randn(kv_len, num_heads, head_dim, device=dev, dtype=torch.bfloat16) + v = torch.randn_like(k) + qo_lens = cu_q[1:] - cu_q[:-1] + kv_lens = cu_k[1:] - cu_k[:-1] + sm_scale = 1.0 / math.sqrt(head_dim) + + if override == "int": + ov = 5 + ref_plan = fmha_sm12x_plan(qo_lens, kv_lens, num_heads, num_heads, qo_offset=5, causal=True) + elif override == "tensor": + ov = torch.tensor([6], dtype=torch.int32) + ref_plan = fmha_sm12x_plan(qo_lens, kv_lens, num_heads, num_heads, qo_offset=ov, causal=True) + else: + ov = None + ref_plan = fmha_sm12x_plan(qo_lens, kv_lens, num_heads, num_heads, causal=True) + + base_plan = fmha_sm12x_plan(qo_lens, kv_lens, num_heads, num_heads, causal=True) + out_ov, _ = fmha_sm12x(q, k, v, base_plan, q_offset_override=ov, sm_scale=sm_scale) + out_ref, _ = fmha_sm12x(q, k, v, ref_plan, sm_scale=sm_scale) + torch.testing.assert_close(out_ov, out_ref, atol=0, rtol=0) + + +def test_paged_sparse_attention_matches_dense(): + # sparse_atten_func paged-KV reference path must equal the dense path for + # the same logical KV (the Triton fast path is dense-only, so this exercises + # the paged Torch fallback). + _need_cuda() + from fmha_sm12x import build_k2q_csr, sparse_atten_func + + torch.manual_seed(3) + dev = torch.device("cuda") + total_q, num_kv_heads, h_ratio, head_dim, page_size, topk = 2, 1, 2, 128, 128, 4 + num_qo_heads = num_kv_heads * h_ratio + n_pages = 4 + kv_len = n_pages * page_size + cu_q = torch.tensor([0, total_q], dtype=torch.int32, device=dev) + cu_k = torch.tensor([0, kv_len], dtype=torch.int32, device=dev) + q = torch.randn(total_q, num_qo_heads, head_dim, device=dev, dtype=torch.bfloat16) * 0.3 + k_dense = torch.randn(kv_len, num_kv_heads, head_dim, device=dev, dtype=torch.bfloat16) * 0.3 + v_dense = torch.randn_like(k_dense) + sm_scale = 1.0 / math.sqrt(head_dim) + + # Same logical KV, paged as [pages, Hkv, page_size, D] with an identity table. + k_paged = k_dense.reshape(n_pages, page_size, num_kv_heads, head_dim).permute(0, 2, 1, 3).contiguous() + v_paged = v_dense.reshape(n_pages, page_size, num_kv_heads, head_dim).permute(0, 2, 1, 3).contiguous() + page_table = torch.arange(n_pages, device=dev, dtype=torch.int32).reshape(1, n_pages) + + q2k = torch.full((num_kv_heads, total_q, topk), -1, dtype=torch.int32, device=dev) + q2k[0, :, 0] = 0 + q2k[0, :, 1] = 2 + row_ptr, q_idx = build_k2q_csr(q2k, cu_q, cu_k, page_size) + common = dict( + cu_seqlens_q=cu_q, cu_seqlens_k=cu_k, max_seqlen_q=total_q, max_seqlen_k=kv_len, + blk_kv=page_size, causal=True, softmax_scale=sm_scale, + ) + out_dense = sparse_atten_func(q, k_dense, v_dense, row_ptr, q_idx, topk, **common) + out_paged = sparse_atten_func(q, k_paged, v_paged, row_ptr, q_idx, topk, page_table=page_table, **common) + torch.testing.assert_close(out_paged.float(), out_dense.float(), atol=2e-2, rtol=2e-2) + + +def test_fp4_indexer_block_scores_matches_dequant_reference(): + # fp4_indexer_block_scores (public scales) must equal a Torch dequant + per + # 128-tile max-logit reference. + _need_cuda() + from fmha_sm12x import fp4_indexer_block_scores + from fmha_sm12x._fp4 import _FP4_VALUES + + dev = torch.device("cuda") + groups = 8 # nvfp4 scale groups per 128-dim row + nibble = 2 # -> fp4 value 1.0 + q_fp4 = torch.full((1, 1, 64), nibble | (nibble << 4), dtype=torch.uint8, device=dev) + k_fp4 = torch.full((1, 1, 128, 64), nibble | (nibble << 4), dtype=torch.uint8, device=dev) + q_scale = torch.full((1, 1, groups), 1.0, device=dev).to(torch.float8_e4m3fn) + k_scale = torch.full((1, 1, 128, groups), 1.0, device=dev).to(torch.float8_e4m3fn) + cu_q = torch.tensor([0, 1], dtype=torch.int32, device=dev) + cu_k = torch.tensor([0, 128], dtype=torch.int32, device=dev) + page_offsets = torch.tensor([0, 1], dtype=torch.int32, device=dev) + kv_indices = torch.tensor([0], dtype=torch.int32, device=dev) + + scores = fp4_indexer_block_scores( + q_fp4, k_fp4, q_scale, k_scale, cu_q, cu_k, page_offsets, + max_seqlen_q=1, max_seqlen_k=128, kv_indices=kv_indices, fp4_format="nvfp4", causal=False, + ) + # Dequant: every element is fp4 value 1.0 with scale 1.0 -> q=k=1.0 over 128 + # dims, so each logit = 128. Single tile -> max score 128. + val = _FP4_VALUES[nibble] + expected = float(val * val * 128) + assert scores.shape == (1, 1, 1) + torch.testing.assert_close(scores[0, 0, 0], torch.tensor(expected, device=dev), atol=1e-3, rtol=1e-3) + + +def _build_dense_sparse_case(dev, *, dtype=torch.bfloat16): + total_q, num_kv_heads, h_ratio, head_dim, page_size, topk = 3, 1, 4, 128, 128, 4 + num_qo_heads = num_kv_heads * h_ratio + n_pages = 4 + kv_len = n_pages * page_size + cu_q = torch.tensor([0, total_q], dtype=torch.int32, device=dev) + cu_k = torch.tensor([0, kv_len], dtype=torch.int32, device=dev) + q = (torch.randn(total_q, num_qo_heads, head_dim, device=dev, dtype=torch.bfloat16) * 0.3).to(dtype) + k = torch.randn(kv_len, num_kv_heads, head_dim, device=dev, dtype=torch.bfloat16) * 0.3 + v = torch.randn_like(k) + q2k = torch.full((num_kv_heads, total_q, topk), -1, dtype=torch.int32, device=dev) + q2k[0, :, 0] = 0 + q2k[0, :, 1] = 3 + return dict( + q=q, k=k, v=v, q2k=q2k, cu_q=cu_q, cu_k=cu_k, topk=topk, + page_size=page_size, kv_len=kv_len, total_q=total_q, + sm_scale=1.0 / math.sqrt(head_dim), + ) + + +def _fp8_available(): + return hasattr(torch, "float8_e4m3fn") + + +def test_sparse_atten_fp8_kv_matches_dequantized_bf16(): + # bf16 Q + fp8 E4M3 K/V cache must equal the explicitly dequantized-bf16 + # K/V (SM12x stages fp8 -> bf16, so this is exact). + _need_cuda() + if not _fp8_available(): + pytest.skip("float8_e4m3fn unavailable") + from fmha_sm12x import build_k2q_csr, sparse_atten_func + + torch.manual_seed(5) + c = _build_dense_sparse_case(torch.device("cuda")) + row_ptr, q_idx = build_k2q_csr(c["q2k"], c["cu_q"], c["cu_k"], c["page_size"]) + k_fp8 = c["k"].to(torch.float8_e4m3fn) + v_fp8 = c["v"].to(torch.float8_e4m3fn) + common = dict( + cu_seqlens_q=c["cu_q"], cu_seqlens_k=c["cu_k"], max_seqlen_q=c["total_q"], + max_seqlen_k=c["kv_len"], blk_kv=c["page_size"], causal=True, softmax_scale=c["sm_scale"], + ) + out_fp8 = sparse_atten_func(c["q"], k_fp8, v_fp8, row_ptr, q_idx, c["topk"], **common) + out_ref = sparse_atten_func( + c["q"], k_fp8.to(torch.bfloat16), v_fp8.to(torch.bfloat16), row_ptr, q_idx, c["topk"], **common + ) + torch.testing.assert_close(out_fp8, out_ref, atol=0, rtol=0) + + +def test_sparse_atten_fp8_qkv_matches_dequantized_bf16(): + # All-fp8 Q/K/V must equal the dequantized-bf16 inputs. + _need_cuda() + if not _fp8_available(): + pytest.skip("float8_e4m3fn unavailable") + from fmha_sm12x import build_k2q_csr, sparse_atten_func + + torch.manual_seed(6) + c = _build_dense_sparse_case(torch.device("cuda")) + row_ptr, q_idx = build_k2q_csr(c["q2k"], c["cu_q"], c["cu_k"], c["page_size"]) + q_fp8 = c["q"].to(torch.float8_e4m3fn) + k_fp8 = c["k"].to(torch.float8_e4m3fn) + v_fp8 = c["v"].to(torch.float8_e4m3fn) + common = dict( + cu_seqlens_q=c["cu_q"], cu_seqlens_k=c["cu_k"], max_seqlen_q=c["total_q"], + max_seqlen_k=c["kv_len"], blk_kv=c["page_size"], causal=True, softmax_scale=c["sm_scale"], + ) + out_fp8 = sparse_atten_func(q_fp8, k_fp8, v_fp8, row_ptr, q_idx, c["topk"], **common) + out_ref = sparse_atten_func( + q_fp8.to(torch.bfloat16), k_fp8.to(torch.bfloat16), v_fp8.to(torch.bfloat16), + row_ptr, q_idx, c["topk"], **common, + ) + torch.testing.assert_close(out_fp8, out_ref, atol=0, rtol=0) + + +def test_sparse_atten_rejects_unsupported_mixed_dtypes(): + # fp16 Q with fp8 K (not an SM100-supported combination) is rejected. + _need_cuda() + if not _fp8_available(): + pytest.skip("float8_e4m3fn unavailable") + from fmha_sm12x import build_k2q_csr, sparse_atten_func + + c = _build_dense_sparse_case(torch.device("cuda"), dtype=torch.float16) + row_ptr, q_idx = build_k2q_csr(c["q2k"], c["cu_q"], c["cu_k"], c["page_size"]) + with pytest.raises(TypeError, match="share a dtype"): + sparse_atten_func( + c["q"], c["k"].to(torch.float8_e4m3fn), c["v"].to(torch.float8_e4m3fn), + row_ptr, q_idx, c["topk"], cu_seqlens_q=c["cu_q"], cu_seqlens_k=c["cu_k"], + max_seqlen_q=c["total_q"], max_seqlen_k=c["kv_len"], blk_kv=c["page_size"], causal=True, + ) + + +def test_sparse_atten_nvfp4_kv_matches_dequantized_bf16(): + # The NVFP4 K/V entry must equal sparse_atten_func on the dequantized BF16 + # K/V (mirrors the SM100 nvfp4-matches-dequant check). + _need_cuda() + from fmha_sm12x import build_k2q_csr, sparse_atten_func, sparse_atten_nvfp4_kv_func + from fmha_sm12x._nvfp4 import dequantize_nvfp4_128x4 + + dev = torch.device("cuda") + head_dim, page_size, topk = 128, 128, 4 + kv_len = page_size # one page + q = torch.ones((1, 1, head_dim), dtype=torch.bfloat16, device=dev) + k_fp4 = torch.full((kv_len, 1, 64), 2 | (2 << 4), dtype=torch.uint8, device=dev) + v_fp4 = torch.full((kv_len, 1, 64), 4 | (4 << 4), dtype=torch.uint8, device=dev) + scale = torch.ones((kv_len, 8), device=dev).to(torch.float8_e4m3fn) + cu_q = torch.tensor([0, 1], dtype=torch.int32, device=dev) + cu_k = torch.tensor([0, kv_len], dtype=torch.int32, device=dev) + q2k = torch.tensor([[[0, -1, -1, -1]]], dtype=torch.int32, device=dev) + row_ptr, q_idx = build_k2q_csr(q2k, cu_q, cu_k, page_size) + common = dict( + cu_seqlens_q=cu_q, cu_seqlens_k=cu_k, max_seqlen_q=1, max_seqlen_k=kv_len, + blk_kv=page_size, causal=False, softmax_scale=1.0 / math.sqrt(head_dim), + ) + + out_nvfp4 = sparse_atten_nvfp4_kv_func( + q, k_fp4, v_fp4, scale, scale, None, None, row_ptr, q_idx, topk, **common + ) + logical = (*k_fp4.shape[:-1], head_dim) + k_bf16 = dequantize_nvfp4_128x4(k_fp4, scale, None, original_shape=logical) + v_bf16 = dequantize_nvfp4_128x4(v_fp4, scale, None, original_shape=logical) + out_ref = sparse_atten_func(q, k_bf16, v_bf16, row_ptr, q_idx, topk, **common) + torch.testing.assert_close(out_nvfp4, out_ref, atol=0, rtol=0) + + +def test_sparse_decode_full_matches_torch_sdpa(): + # sparse_decode_atten_func with no block selection == dense causal SDPA over + # seqused_k tokens (paged KV decode reference). + _need_cuda() + from fmha_sm12x import sparse_decode_atten_func + + torch.manual_seed(4) + dev = torch.device("cuda") + num_kv_heads, h_ratio, head_dim, page_size = 1, 2, 128, 128 + num_qo_heads = num_kv_heads * h_ratio + seqlen_q, used = 1, 200 + n_pages = (used + page_size - 1) // page_size + q = torch.randn(seqlen_q, num_qo_heads, head_dim, device=dev, dtype=torch.bfloat16) * 0.3 + k = torch.randn(n_pages, num_kv_heads, page_size, head_dim, device=dev, dtype=torch.bfloat16) * 0.3 + v = torch.randn_like(k) + page_table = torch.arange(n_pages, device=dev, dtype=torch.int32).reshape(1, n_pages) + seqused = torch.tensor([used], dtype=torch.int32, device=dev) + sm_scale = 1.0 / math.sqrt(head_dim) + + out = sparse_decode_atten_func( + q, k, v, page_table=page_table, seqused_k=seqused, seqlen_q=seqlen_q, + max_seqlen_k=n_pages * page_size, blk_kv=page_size, causal=True, softmax_scale=sm_scale, + ) + + # Flatten paged KV to [used, Hkv, D] and run dense causal SDPA for the single + # decode token (it sees all `used` tokens). + k_flat = k.permute(0, 2, 1, 3).reshape(n_pages * page_size, num_kv_heads, head_dim)[:used] + v_flat = v.permute(0, 2, 1, 3).reshape(n_pages * page_size, num_kv_heads, head_dim)[:used] + ref = torch.zeros((seqlen_q, num_qo_heads, head_dim), dtype=torch.float32, device=dev) + for h in range(num_qo_heads): + kvh = h // h_ratio + logits = (k_flat[:, kvh].float() @ q[0, h].float()) * sm_scale + ref[0, h] = torch.softmax(logits, dim=0) @ v_flat[:, kvh].float() + torch.testing.assert_close(out.float(), ref, atol=2e-2, rtol=2e-2) diff --git a/tests/test_sm12x_reference.py b/tests/test_sm12x_reference.py new file mode 100644 index 0000000..27e91d4 --- /dev/null +++ b/tests/test_sm12x_reference.py @@ -0,0 +1,577 @@ +# SPDX-License-Identifier: MIT + +from __future__ import annotations + +import math + +import pytest +import torch + + +def _need_cuda() -> None: + if not torch.cuda.is_available(): + pytest.skip("CUDA is required") + + +def _need_sm12x_cuda() -> None: + _need_cuda() + major, minor = torch.cuda.get_device_capability() + if major != 12 or minor not in (0, 1): + pytest.skip("SM120/SM121 CUDA device is required") + + +def test_sm12x_dense_reference_matches_torch() -> None: + _need_cuda() + from fmha_sm12x import fmha_sm12x, fmha_sm12x_plan + + device = torch.device("cuda") + q = torch.randn((3, 2, 16), device=device, dtype=torch.bfloat16) + k = torch.randn((5, 1, 16), device=device, dtype=torch.bfloat16) + v = torch.randn((5, 1, 16), device=device, dtype=torch.bfloat16) + qo = torch.tensor([1, 2], dtype=torch.int32) + kv = torch.tensor([2, 3], dtype=torch.int32) + plan = fmha_sm12x_plan(qo, kv, 2, 1, causal=True) + + out, _ = fmha_sm12x(q, k, v, plan, sm_scale=1.0 / math.sqrt(16)) + + refs = [] + q_start = 0 + k_start = 0 + for q_len, k_len in [(1, 2), (2, 3)]: + per_batch = [] + for local_q in range(q_len): + visible = k_len - q_len + local_q + 1 + per_head = [] + for head in range(2): + logits = torch.matmul( + k[k_start : k_start + visible, 0].float(), + q[q_start + local_q, head].float(), + ) / math.sqrt(16) + per_head.append(torch.matmul(torch.softmax(logits, dim=0), v[k_start : k_start + visible, 0].float())) + per_batch.append(torch.stack(per_head, dim=0)) + refs.append(torch.stack(per_batch, dim=0)) + q_start += q_len + k_start += k_len + expected = torch.cat(refs, dim=0).to(torch.bfloat16) + torch.testing.assert_close(out, expected, atol=2e-2, rtol=2e-2) + + +def test_sm12x_sparse_reference_uses_selected_blocks() -> None: + _need_cuda() + from fmha_sm12x import fmha_sm12x, fmha_sm12x_plan + + device = torch.device("cuda") + page = 2 + q = torch.randn((1, 1, 8), device=device, dtype=torch.bfloat16) + k = torch.randn((4, 1, 8), device=device, dtype=torch.bfloat16) + v = torch.randn((4, 1, 8), device=device, dtype=torch.bfloat16) + blocks = torch.tensor([[[1]]], device=device, dtype=torch.int32) + plan = fmha_sm12x_plan(torch.tensor([1], dtype=torch.int32), torch.tensor([4], dtype=torch.int32), 1, 1, qo_offset=3, page_size=page, kv_block_num=1) + + out, _ = fmha_sm12x(q, k, v, plan, kv_block_indexes=blocks, sm_scale=1.0 / math.sqrt(8)) + + logits = torch.matmul(k[2:4, 0].float(), q[0, 0].float()) / math.sqrt(8) + expected = torch.matmul(torch.softmax(logits, dim=0), v[2:4, 0].float()).to(torch.bfloat16) + torch.testing.assert_close(out[0, 0], expected, atol=2e-2, rtol=2e-2) + + +def test_sm12x_reference_rejects_plan_total_q_mismatch() -> None: + _need_cuda() + from fmha_sm12x import fmha_sm12x, fmha_sm12x_plan + + q = torch.randn((2, 1, 16), device="cuda", dtype=torch.bfloat16) + k = torch.randn((1, 1, 16), device="cuda", dtype=torch.bfloat16) + v = torch.randn_like(k) + plan = fmha_sm12x_plan( + torch.tensor([1], dtype=torch.int32), + torch.tensor([1], dtype=torch.int32), + 1, + 1, + causal=False, + ) + + with pytest.raises(ValueError, match="sum\\(qo_segment_lens\\)"): + fmha_sm12x(q, k, v, plan) + + +def test_sm12x_build_k2q_rejects_cross_batch_block_index() -> None: + _need_cuda() + from fmha_sm12x import build_k2q_csr + + q2k = torch.tensor([[[1], [0]]], device="cuda", dtype=torch.int32) + cu_q = torch.tensor([0, 1, 2], device="cuda", dtype=torch.int32) + cu_k = torch.tensor([0, 2, 4], device="cuda", dtype=torch.int32) + + with pytest.raises(ValueError, match="batch 0"): + build_k2q_csr(q2k, cu_q, cu_k, 2) + + +def test_sm12x_sparse_rejects_invalid_csr_query_index() -> None: + _need_cuda() + from fmha_sm12x import sparse_atten_func + + q = torch.randn((1, 1, 16), device="cuda", dtype=torch.bfloat16) + k = torch.randn((2, 1, 16), device="cuda", dtype=torch.bfloat16) + v = torch.randn_like(k) + row_ptr = torch.tensor([[0, 1]], device="cuda", dtype=torch.int32) + q_indices = torch.tensor([[-1]], device="cuda", dtype=torch.int32) + cu_q = torch.tensor([0, 1], device="cuda", dtype=torch.int32) + cu_k = torch.tensor([0, 2], device="cuda", dtype=torch.int32) + + with pytest.raises(ValueError, match="q index"): + sparse_atten_func( + q, + k, + v, + row_ptr, + q_indices, + 1, + cu_seqlens_q=cu_q, + cu_seqlens_k=cu_k, + max_seqlen_q=1, + max_seqlen_k=2, + blk_kv=2, + causal=False, + ) + + +def test_sm12x_sparse_atten_temperature_lse_parity() -> None: + # Parity with SM100 sparse_atten_func: return_temperature_lse yields a + # 3-tuple (out, lse, temperature_lse); with lse_temperature_scale=1.0 the + # temperature LSE equals the plain LSE, and a >1 scale shrinks it. + _need_cuda() + from fmha_sm12x import sparse_atten_func + + torch.manual_seed(0) + q = torch.randn((2, 1, 16), device="cuda", dtype=torch.bfloat16) + k = torch.randn((2, 1, 16), device="cuda", dtype=torch.bfloat16) + v = torch.randn_like(k) + row_ptr = torch.tensor([[0, 2]], device="cuda", dtype=torch.int32) + q_indices = torch.tensor([[0, 1]], device="cuda", dtype=torch.int32) + cu_q = torch.tensor([0, 2], device="cuda", dtype=torch.int32) + cu_k = torch.tensor([0, 2], device="cuda", dtype=torch.int32) + common = dict( + cu_seqlens_q=cu_q, cu_seqlens_k=cu_k, max_seqlen_q=2, max_seqlen_k=2, + blk_kv=2, causal=False, softmax_scale=0.5, return_softmax_lse=True, + ) + + out, lse = sparse_atten_func(q, k, v, row_ptr, q_indices, 1, **common) + out3, lse3, temp_lse = sparse_atten_func( + q, k, v, row_ptr, q_indices, 1, return_temperature_lse=True, + lse_temperature_scale=1.0, **common, + ) + torch.testing.assert_close(out3, out, atol=0, rtol=0) + torch.testing.assert_close(lse3, lse, atol=0, rtol=0) + torch.testing.assert_close(temp_lse, lse, atol=1e-5, rtol=1e-5) + + _, _, temp_lse2 = sparse_atten_func( + q, k, v, row_ptr, q_indices, 1, return_temperature_lse=True, + lse_temperature_scale=4.0, **common, + ) + assert temp_lse2.shape == lse.shape + assert bool(torch.isfinite(temp_lse2).all().item()) + + with pytest.raises(ValueError, match="return_temperature_lse"): + sparse_atten_func( + q, k, v, row_ptr, q_indices, 1, return_temperature_lse=True, + cu_seqlens_q=cu_q, cu_seqlens_k=cu_k, max_seqlen_q=2, max_seqlen_k=2, + blk_kv=2, return_softmax_lse=False, + ) + + +def test_sm12x_topk_export_runs() -> None: + _need_cuda() + from fmha_sm12x import sparse_topk_select + + scores = torch.full((1, 8, 1), -1000.0, device="cuda", dtype=torch.float32) + scores[0, 3, 0] = 1.0 + out = sparse_topk_select(scores.contiguous(), 16, num_valid_pages=8) + assert out.shape == (1, 1, 16) + selected = set(out[0, 0].to("cpu", non_blocking=False).tolist()) + assert 3 in selected + + +def _packed_fp4(shape: tuple[int, ...], nibble: int) -> torch.Tensor: + return torch.full(shape, int(nibble) | (int(nibble) << 4), dtype=torch.uint8, device="cuda") + + +def test_sm12x_fp4_indexer_block_scores_runs() -> None: + _need_cuda() + from fmha_sm12x import fp4_indexer_block_scores + + q = _packed_fp4((1, 1, 64), 2) + k = _packed_fp4((1, 1, 128, 64), 2) + q_scale = torch.ones((1, 1, 8), device="cuda").to(torch.float8_e4m3fn) + k_scale = torch.ones((1, 1, 128, 8), device="cuda").to(torch.float8_e4m3fn) + cu_q = torch.tensor([0, 1], dtype=torch.int32, device="cuda") + cu_k = torch.tensor([0, 128], dtype=torch.int32, device="cuda") + page_offsets = torch.tensor([0, 1], dtype=torch.int32, device="cuda") + kv_indices = torch.tensor([0], dtype=torch.int32, device="cuda") + + scores = fp4_indexer_block_scores( + q, k, q_scale, k_scale, cu_q, cu_k, page_offsets, + max_seqlen_q=1, max_seqlen_k=128, kv_indices=kv_indices, + fp4_format="nvfp4", scale_layout="public", + ) + + assert scores.shape == (1, 1, 1) + torch.testing.assert_close(scores[0, 0, 0], torch.tensor(128.0, device="cuda")) + + +def test_sm12x_nvfp4_helper_exports_dequantize() -> None: + _need_cuda() + from fmha_sm12x import ( + Nvfp4QuantizedTensor, + dequantize_nvfp4_128x4_to_bf16, + nvfp4_global_scale_from_amax, + swizzle_nvfp4_scale_to_128x4, + ) + + data = _packed_fp4((1, 1, 64), 2) + scale = swizzle_nvfp4_scale_to_128x4( + torch.ones((1, 8), device="cuda").to(torch.float8_e4m3fn), + rows=1, + cols=8, + ) + global_scale = torch.ones((1,), device="cuda", dtype=torch.float32) + quantized = Nvfp4QuantizedTensor( + data=data, + scale_128x4=scale, + global_scale=global_scale, + logical_scale_shape=(1, 8), + original_shape=(1, 1, 128), + ) + + out = dequantize_nvfp4_128x4_to_bf16(quantized) + + torch.testing.assert_close(out, torch.ones_like(out), atol=0, rtol=0) + expected_scale = torch.ones((1,), device="cuda", dtype=torch.float32) + torch.testing.assert_close(nvfp4_global_scale_from_amax(torch.full_like(expected_scale, 2688.0)), expected_scale) + + +def test_sm12x_build_k2q_csr_uses_sm100_block_major_row_order() -> None: + from fmha_sm12x import build_k2q_csr + + q2k = torch.tensor([[[1], [-1], [0], [0]]], dtype=torch.int32) + cu_q = torch.tensor([0, 2, 4], dtype=torch.int32) + cu_k = torch.tensor([0, 2, 4], dtype=torch.int32) + + row_ptr, q_indices = build_k2q_csr(q2k, cu_q, cu_k, 1) + + torch.testing.assert_close(row_ptr, torch.tensor([[0, 0, 2, 3, 3]], dtype=torch.int32)) + torch.testing.assert_close(q_indices[:, :3], torch.tensor([[0, 1, 0]], dtype=torch.int32)) + + +def test_sm12x_optimized_k2q_builder_matches_reference() -> None: + _need_sm12x_cuda() + from fmha_sm12x import SparseK2qCsrBuilderSm12x, build_k2q_csr + + q2k = torch.tensor( + [[[0, -1, -1, -1], [0, -1, -1, -1], [0, 1, -1, -1], [1, -1, -1, -1]]], + device="cuda", + dtype=torch.int32, + ) + cu_q = torch.tensor([0, 2, 4], device="cuda", dtype=torch.int32) + cu_k = torch.tensor([0, 128, 384], device="cuda", dtype=torch.int32) + + ref_row_ptr, ref_q_indices = build_k2q_csr(q2k, cu_q, cu_k, 128) + row_ptr, q_indices = SparseK2qCsrBuilderSm12x()(q2k, cu_q, cu_k, total_k=384, blk_kv=128) + torch.cuda.synchronize() + + torch.testing.assert_close(row_ptr, ref_row_ptr) + nnz = int(row_ptr[0, -1].item()) + ref_nnz = int(ref_row_ptr[0, -1].item()) + torch.testing.assert_close(q_indices[:, :nnz], ref_q_indices[:, :ref_nnz]) + torch.testing.assert_close(q_indices[:, nnz:], torch.full_like(q_indices[:, nnz:], -1)) + + +def test_sm12x_optimized_k2q_builder_handles_per_batch_partial_rows() -> None: + _need_sm12x_cuda() + from fmha_sm12x import SparseK2qCsrBuilderSm12x, build_k2q_csr + + q2k = torch.tensor( + [[[0, -1, -1, -1], [0, -1, -1, -1]]], + device="cuda", + dtype=torch.int32, + ) + cu_q = torch.tensor([0, 1, 2], device="cuda", dtype=torch.int32) + cu_k = torch.tensor([0, 1, 128], device="cuda", dtype=torch.int32) + + ref_row_ptr, ref_q_indices = build_k2q_csr(q2k, cu_q, cu_k, 128) + row_ptr, q_indices = SparseK2qCsrBuilderSm12x()(q2k, cu_q, cu_k, total_k=128, blk_kv=128) + torch.cuda.synchronize() + + torch.testing.assert_close(row_ptr, ref_row_ptr) + nnz = int(row_ptr[0, -1].item()) + torch.testing.assert_close(q_indices[:, :nnz], ref_q_indices[:, :nnz]) + + +def test_sm12x_optimized_k2q_builder_rejects_bad_cu_seqlens() -> None: + _need_sm12x_cuda() + from fmha_sm12x import SparseK2qCsrBuilderSm12x + + q2k = torch.zeros((1, 2, 4), device="cuda", dtype=torch.int32) + cu_k = torch.tensor([0, 128], device="cuda", dtype=torch.int32) + bad_cu_q = torch.tensor([0, 1, 1], device="cuda", dtype=torch.int32) + + with pytest.raises(ValueError, match="cu_seqlens_q"): + SparseK2qCsrBuilderSm12x()(q2k, bad_cu_q, cu_k, total_k=128, blk_kv=128) + + +def test_sm12x_optimized_k2q_builder_returns_schedule() -> None: + _need_sm12x_cuda() + from fmha_sm12x import SparseK2qCsrBuilderSm12x + + q2k = torch.tensor( + [[[0, -1, -1, -1], [0, -1, -1, -1], [0, 1, -1, -1], [1, -1, -1, -1]]], + device="cuda", + dtype=torch.int32, + ) + cu_q = torch.tensor([0, 2, 4], device="cuda", dtype=torch.int32) + cu_k = torch.tensor([0, 128, 384], device="cuda", dtype=torch.int32) + + row_ptr, q_indices, schedule = SparseK2qCsrBuilderSm12x()( + q2k, + cu_q, + cu_k, + total_k=384, + blk_kv=128, + max_seqlen_q=2, + max_seqlen_k=256, + return_schedule=True, + ) + torch.cuda.synchronize() + + assert schedule.enabled + assert schedule.scheduler_metadata is not None + assert schedule.scheduler_metadata.shape[1] == 6 + assert schedule.work_count is not None + assert int(schedule.work_count.item()) >= 0 + assert schedule.qsplit_indices is not None + assert schedule.qsplit_indices.shape == q_indices.shape + assert schedule.split_counts is not None + assert schedule.split_counts.shape == (4, 1) + assert int(row_ptr[0, -1].item()) == 5 + + +def test_sm12x_sparse_nvfp4_prefill_runs() -> None: + _need_cuda() + from fmha_sm12x import sparse_atten_nvfp4_kv_func + + q = torch.ones((1, 1, 128), dtype=torch.bfloat16, device="cuda") + k = _packed_fp4((128, 1, 64), 2) + v = _packed_fp4((128, 1, 64), 4) + scale = torch.ones((128, 8), device="cuda").to(torch.float8_e4m3fn) + row_ptr = torch.tensor([[0, 1]], dtype=torch.int32, device="cuda") + q_indices = torch.tensor([[0]], dtype=torch.int32, device="cuda") + cu_q = torch.tensor([0, 1], dtype=torch.int32, device="cuda") + cu_k = torch.tensor([0, 128], dtype=torch.int32, device="cuda") + + out, lse = sparse_atten_nvfp4_kv_func( + q, k, v, scale, scale, None, None, row_ptr, q_indices, 1, + cu_seqlens_q=cu_q, cu_seqlens_k=cu_k, max_seqlen_q=1, max_seqlen_k=128, + causal=False, return_softmax_lse=True, + ) + + expected_lse = torch.tensor(math.sqrt(128.0) + math.log(128.0), device="cuda", dtype=torch.float32) + torch.testing.assert_close(out, torch.full_like(out, 2.0), atol=0, rtol=0) + torch.testing.assert_close(lse, expected_lse.reshape(1, 1), atol=1e-5, rtol=1e-5) + + # Parity: the NVFP4 variant forwards temperature-LSE outputs (3-tuple). + out3, lse3, temp_lse = sparse_atten_nvfp4_kv_func( + q, k, v, scale, scale, None, None, row_ptr, q_indices, 1, + cu_seqlens_q=cu_q, cu_seqlens_k=cu_k, max_seqlen_q=1, max_seqlen_k=128, + causal=False, return_softmax_lse=True, return_temperature_lse=True, + lse_temperature_scale=1.0, + ) + torch.testing.assert_close(out3, out, atol=0, rtol=0) + torch.testing.assert_close(temp_lse, lse3, atol=1e-5, rtol=1e-5) + + +def test_sm12x_build_k2q_csr_reference_rejects_return_schedule() -> None: + # The pure-Torch reference cannot emit the fused schedule; it must fail + # clearly rather than return a 2-tuple a 3-tuple caller would mis-unpack. + _need_cuda() + from fmha_sm12x import build_k2q_csr + + q2k = torch.tensor([[[0, -1, -1, -1]]], device="cuda", dtype=torch.int32) + cu_q = torch.tensor([0, 1], device="cuda", dtype=torch.int32) + cu_k = torch.tensor([0, 128], device="cuda", dtype=torch.int32) + with pytest.raises(ValueError, match="return_schedule"): + build_k2q_csr(q2k, cu_q, cu_k, 128, return_schedule=True) + + +def test_sm12x_decode_schedule_helper_runs() -> None: + _need_sm12x_cuda() + from fmha_sm12x.cute.src.sm12x.decode_schedule import prepare_decode_schedule + + seqused_k = torch.tensor([128, 256], device="cuda", dtype=torch.int32) + schedule = prepare_decode_schedule( + seqused_k=seqused_k, + page_size=128, + seqlen_q=8, + num_qo_heads=16, + num_kv_heads=1, + head_dim=128, + max_seqlen_k=256, + disable_split_kv=True, + ) + torch.cuda.synchronize() + + assert schedule.request_indices.is_cuda + assert schedule.request_indices.shape[0] == schedule.padded_work_count + assert schedule.block_valid_mask.shape[0] == schedule.padded_work_count + assert schedule.merge_indptr.ndim == 1 + assert schedule.o_indptr.ndim == 1 + assert int(schedule.merge_indptr[0].item()) == 0 + assert int(schedule.o_indptr[0].item()) == 0 + assert schedule.split_counts.shape == (2,) + assert schedule.work_count > 0 + + +def test_sm12x_decode_schedule_cuda_graph_padding_capacity() -> None: + # Regression: when CUDA-graph capture pads work_count up to the captured + # grid, padded_work_count can exceed the page-based allocation bound + # (batch * num_q_tiles * max_pages_global). The output index arrays must + # be allocated to at least the graph pad or the wrapper's + # narrow(0, 0, padded_work_count) reads past the allocation. A small + # max_seqlen_k with a large max_grid_size override forces that case. + _need_sm12x_cuda() + from fmha_sm12x.cute.src.sm12x.decode_schedule import prepare_decode_schedule + + seqused_k = torch.tensor([1152], device="cuda", dtype=torch.int32) + grid_override = 4096 # >> batch * num_q_tiles * max_pages_global (= 9) + schedule = prepare_decode_schedule( + seqused_k=seqused_k, + page_size=128, + seqlen_q=8, + num_qo_heads=16, + num_kv_heads=1, + head_dim=128, + max_seqlen_k=1152, + enable_cuda_graph=True, + max_grid_size=grid_override, + ) + torch.cuda.synchronize() + + assert schedule.split_kv + # CUDA-graph pad = max_grid_size / num_kv_heads, far above the 9-tile + # page bound; the narrow below would have raised before the fix. + assert schedule.padded_work_count >= grid_override + for arr in ( + schedule.request_indices, + schedule.qo_tile_indices, + schedule.kv_tile_indices, + schedule.block_valid_mask, + ): + assert arr.shape[0] == schedule.padded_work_count + # Entries past the real work_count must be zeroed padding (valid mask 0). + assert int(schedule.block_valid_mask[schedule.padded_work_count - 1].item()) == 0 + + +def test_sm12x_decode_schedule_rejects_too_small_max_seqlen_k() -> None: + _need_sm12x_cuda() + from fmha_sm12x.cute.src.sm12x.decode_schedule import prepare_decode_schedule + + seqused_k = torch.tensor([256], device="cuda", dtype=torch.int32) + with pytest.raises(ValueError, match="max_seqlen_k"): + prepare_decode_schedule( + seqused_k=seqused_k, + page_size=128, + seqlen_q=8, + num_qo_heads=16, + num_kv_heads=1, + head_dim=128, + max_seqlen_k=128, + ) + + +def test_sm12x_decode_schedule_raw_entrypoint_guards_hang() -> None: + # Blocker: the raw launch wrapper must reject hang-inducing seqused_k + # (seqused_k[b] < seqlen_q) on its own, so a direct caller that bypasses + # prepare_decode_schedule cannot spin the kernel on an all-masked row. + _need_sm12x_cuda() + from fmha_sm12x.cute.src.sm12x.fwd_decode.build_decode_schedule import ( + build_decode_schedule, + ) + + seqused_k = torch.tensor([4], device="cuda", dtype=torch.int32) # < seqlen_q + with pytest.raises(ValueError, match="seqused_k"): + build_decode_schedule( + seqused_k, + page_size=128, + seqlen_q=8, + num_qo_heads=16, + num_kv_heads=1, + head_dim=128, + max_seqlen_k=128, + ) + + +def test_sm12x_decode_schedule_raw_entrypoint_guards_pad_overflow() -> None: + # Blocker: the raw wrapper must reject seqused_k longer than max_seqlen_k, + # since the work-tile arrays are sized from max_seqlen_k; otherwise the + # kernel scatter / narrow(0, 0, padded_work_count) run out of bounds. + _need_sm12x_cuda() + from fmha_sm12x.cute.src.sm12x.fwd_decode.build_decode_schedule import ( + build_decode_schedule, + ) + + seqused_k = torch.tensor([256], device="cuda", dtype=torch.int32) + with pytest.raises(ValueError, match="max_seqlen_k"): + build_decode_schedule( + seqused_k, + page_size=128, + seqlen_q=8, + num_qo_heads=16, + num_kv_heads=1, + head_dim=128, + max_seqlen_k=128, + ) + + +def test_sm12x_decode_schedule_rejects_short_seqused_k_via_wrapper() -> None: + # The high-level wrapper still surfaces the same guard (now enforced at the + # raw boundary it funnels through). + _need_sm12x_cuda() + from fmha_sm12x.cute.src.sm12x.decode_schedule import prepare_decode_schedule + + seqused_k = torch.tensor([4], device="cuda", dtype=torch.int32) + with pytest.raises(ValueError, match="seqused_k"): + prepare_decode_schedule( + seqused_k=seqused_k, + page_size=128, + seqlen_q=8, + num_qo_heads=16, + num_kv_heads=1, + head_dim=128, + max_seqlen_k=128, + ) + + +def test_sm12x_sparse_decode_wrapper_runs() -> None: + _need_cuda() + from fmha_sm12x import SparseDecodePagedAttentionWrapper, sparse_decode_atten_func + + q = torch.ones((1, 1, 128), dtype=torch.bfloat16, device="cuda") + k = torch.zeros((1, 1, 128, 128), dtype=torch.bfloat16, device="cuda") + v = torch.zeros_like(k) + v[0, 0, 0].fill_(2.0) + v[0, 0, 1].fill_(4.0) + page_table = torch.tensor([[0]], dtype=torch.int32, device="cuda") + seqused = torch.tensor([2], dtype=torch.int32, device="cuda") + + direct, lse = sparse_decode_atten_func( + q, k, v, page_table=page_table, seqused_k=seqused, + seqlen_q=1, max_seqlen_k=2, blk_kv=128, causal=True, + return_softmax_lse=True, + ) + wrapper = SparseDecodePagedAttentionWrapper().plan( + page_table=page_table, seqused_k=seqused, seqlen_q=1, max_seqlen_k=2, + num_qo_heads=1, num_kv_heads=1, head_dim=128, + ) + wrapped = wrapper.run(q, k, v) + + torch.testing.assert_close(direct, torch.full_like(direct, 3.0), atol=0, rtol=0) + torch.testing.assert_close(lse, torch.full_like(lse, math.log(2.0)), atol=1e-6, rtol=1e-6) + torch.testing.assert_close(wrapped, direct, atol=0, rtol=0) diff --git a/tests/test_sm12x_triton_sparse.py b/tests/test_sm12x_triton_sparse.py new file mode 100644 index 0000000..2bc7851 --- /dev/null +++ b/tests/test_sm12x_triton_sparse.py @@ -0,0 +1,151 @@ +# SPDX-License-Identifier: MIT + +from __future__ import annotations + +import math + +import pytest +import torch + + +def _need_triton_sm12x(): + if not torch.cuda.is_available(): + pytest.skip("CUDA is required") + major, _ = torch.cuda.get_device_capability() + if major != 12: + pytest.skip("SM120/SM121 CUDA device is required") + from fmha_sm12x._triton_sparse import triton_available + + if not triton_available(): + pytest.skip("Triton is required") + + +def _reference(q, k, v, block_indexes, *, cu_q, cu_k, num_kv_heads, page_size, causal, sm_scale, want_temp, temp): + from fmha_sm12x._lse import run_lse + from fmha_sm12x.api import fmha_sm12x, fmha_sm12x_plan + + qo_lens = cu_q[1:] - cu_q[:-1] + kv_lens = cu_k[1:] - cu_k[:-1] + plan = fmha_sm12x_plan( + qo_lens, kv_lens, int(q.shape[1]), int(num_kv_heads), page_size=int(page_size), causal=bool(causal) + ) + out, _ = fmha_sm12x(q, k, v, plan, kv_indices=None, kv_block_indexes=block_indexes, sm_scale=sm_scale) + lse = run_lse(q, k, v, plan, kv_indices=None, kv_block_indexes=block_indexes, sm_scale=sm_scale) + tlse = None + if want_temp: + tlse = run_lse(q, k, v, plan, kv_indices=None, kv_block_indexes=block_indexes, sm_scale=sm_scale / temp) + return out, lse, tlse + + +@pytest.mark.parametrize( + "batches,h_ratio,num_kv_heads,topk,page_size,causal,temp", + [ + ([(2, 256), (3, 384)], 16, 1, 2, 128, True, 1.0), + ([(4, 128)], 4, 2, 4, 64, True, 2.0), + ([(1, 64), (2, 96)], 1, 3, 2, 16, False, 1.0), + ([(3, 200)], 8, 1, 1, 128, True, 4.0), + ([(2, 130), (2, 70)], 2, 2, 4, 64, False, 1.0), + ], +) +def test_triton_sparse_matches_reference(batches, h_ratio, num_kv_heads, topk, page_size, causal, temp): + _need_triton_sm12x() + from fmha_sm12x._triton_sparse import triton_sparse_atten_dense + + device = torch.device("cuda") + torch.manual_seed(1234 + topk + page_size) + num_qo_heads = num_kv_heads * h_ratio + head_dim = 128 + + qo_lens = [b[0] for b in batches] + kv_lens = [b[1] for b in batches] + cu_q = torch.tensor([0, *torch.tensor(qo_lens).cumsum(0).tolist()], dtype=torch.int32, device=device) + cu_k = torch.tensor([0, *torch.tensor(kv_lens).cumsum(0).tolist()], dtype=torch.int32, device=device) + total_q = int(cu_q[-1].item()) + total_k = int(cu_k[-1].item()) + + q = torch.randn((total_q, num_qo_heads, head_dim), device=device, dtype=torch.bfloat16) * 0.3 + k = torch.randn((total_k, num_kv_heads, head_dim), device=device, dtype=torch.bfloat16) * 0.3 + v = torch.randn((total_k, num_kv_heads, head_dim), device=device, dtype=torch.bfloat16) * 0.3 + sm_scale = 1.0 / math.sqrt(head_dim) + + # Build per-(query, kv_head) block selections within each batch's KV, with + # some -1 padding and an occasional fully-padded (no-block) query. + block_indexes = torch.full((total_q, num_kv_heads, topk), -1, dtype=torch.int32, device=device) + rng = torch.Generator(device="cpu").manual_seed(7) + qo_cpu = cu_q.cpu().tolist() + for b in range(len(batches)): + n_blocks = (kv_lens[b] + page_size - 1) // page_size + for qi in range(qo_cpu[b], qo_cpu[b + 1]): + for kh in range(num_kv_heads): + if torch.rand(1, generator=rng).item() < 0.1: + continue # leave fully padded + n_sel = int(torch.randint(1, topk + 1, (1,), generator=rng).item()) + perm = torch.randperm(n_blocks, generator=rng)[:n_sel] + block_indexes[qi, kh, : perm.numel()] = perm.to(torch.int32).to(device) + + ref_out, ref_lse, ref_tlse = _reference( + q, k, v, block_indexes, cu_q=cu_q, cu_k=cu_k, num_kv_heads=num_kv_heads, + page_size=page_size, causal=causal, sm_scale=sm_scale, want_temp=(temp != 1.0), temp=temp, + ) + out, lse, tlse = triton_sparse_atten_dense( + q, k, v, block_indexes, cu_seqlens_q=cu_q, cu_seqlens_k=cu_k, + num_kv_heads=num_kv_heads, page_size=page_size, causal=causal, sm_scale=sm_scale, + return_lse=True, lse_temperature_scale=temp, return_temperature_lse=(temp != 1.0), + ) + + torch.testing.assert_close(out.float(), ref_out.float(), atol=2e-2, rtol=2e-2) + + # Compare LSE only where the reference has finite mass (fully-padded + # queries are -inf in both and would fail allclose on inf arithmetic). + finite = torch.isfinite(ref_lse) + torch.testing.assert_close(lse[finite], ref_lse[finite], atol=2e-2, rtol=2e-2) + assert bool((torch.isfinite(lse) == finite).all().item()) + if temp != 1.0: + finite_t = torch.isfinite(ref_tlse) + torch.testing.assert_close(tlse[finite_t], ref_tlse[finite_t], atol=2e-2, rtol=2e-2) + + +def test_triton_sparse_duplicate_and_fully_masked_blocks(): + # Lock two edge invariants against the reference: a block id duplicated in + # a topk row (double-counted identically by both paths) and a selected + # block entirely beyond an early query's causal window (no visible position + # -> zero output, -inf LSE). qo_len == kv_len makes query 0's causal limit + # 0, so any block it selects past block 0 is fully masked. + _need_triton_sm12x() + from fmha_sm12x._triton_sparse import triton_sparse_atten_dense + + device = torch.device("cuda") + torch.manual_seed(99) + head_dim, page_size, num_kv_heads, h_ratio = 64, 16, 1, 2 + num_qo_heads = num_kv_heads * h_ratio + seq = 48 # qo_len == kv_len == 48 -> 3 blocks of 16 + cu_q = torch.tensor([0, seq], dtype=torch.int32, device=device) + cu_k = torch.tensor([0, seq], dtype=torch.int32, device=device) + q = torch.randn((seq, num_qo_heads, head_dim), device=device, dtype=torch.bfloat16) * 0.3 + k = torch.randn((seq, num_kv_heads, head_dim), device=device, dtype=torch.bfloat16) * 0.3 + v = torch.randn((seq, num_kv_heads, head_dim), device=device, dtype=torch.bfloat16) * 0.3 + sm_scale = 1.0 / math.sqrt(head_dim) + + block_indexes = torch.full((seq, num_kv_heads, 3), -1, dtype=torch.int32, device=device) + block_indexes[0, 0, 0] = 2 # query 0 causal limit 0: block 2 fully masked + block_indexes[20, 0, 0] = 0 + block_indexes[20, 0, 1] = 0 # duplicate, finite mass + block_indexes[40, 0, 0] = 1 + block_indexes[40, 0, 1] = 2 + + ref_out, ref_lse, _ = _reference( + q, k, v, block_indexes, cu_q=cu_q, cu_k=cu_k, num_kv_heads=num_kv_heads, + page_size=page_size, causal=True, sm_scale=sm_scale, want_temp=False, temp=1.0, + ) + out, lse, _ = triton_sparse_atten_dense( + q, k, v, block_indexes, cu_seqlens_q=cu_q, cu_seqlens_k=cu_k, + num_kv_heads=num_kv_heads, page_size=page_size, causal=True, sm_scale=sm_scale, + return_lse=True, out_dtype=torch.bfloat16, + ) + torch.testing.assert_close(out.float(), ref_out.float(), atol=2e-2, rtol=2e-2) + # query 0 selects only a fully-masked block: zero output, -inf LSE. + assert bool((out[0] == 0).all().item()) + assert bool(torch.isneginf(lse[0]).all().item()) + finite = torch.isfinite(ref_lse) + torch.testing.assert_close(lse[finite], ref_lse[finite], atol=2e-2, rtol=2e-2) + assert bool((torch.isfinite(lse) == finite).all().item())