Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 46 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

[![License: MIT](https://img.shields.io/badge/license-MIT-blue.svg)](LICENSE)
[![Python](https://img.shields.io/badge/python-≥3.10-blue.svg)](pyproject.toml)
[![GPU](https://img.shields.io/badge/NVIDIA-SM100-76b900.svg)](#requirements)
[![GPU](https://img.shields.io/badge/NVIDIA-SM100%20%2F%20SM103-76b900.svg)](#requirements)
[![Stack: CuTe-DSL + Cuda](https://img.shields.io/badge/stack-CuTe--DSL%20%2B%20Cuda-purple.svg)](#stacks)

**MSA** (`fmha_sm100`) ships dense FlashAttention and sparse top-k attention
kernels for **NVIDIA SM100**. Two JIT-compiled stacks
share one Python package:
**MSA** (`fmha_sm100`, plus the companion `fmha_sm12x` namespace) ships
dense FlashAttention and sparse top-k attention kernels for **NVIDIA SM100 /
SM103**. Those kernels use SM100-only tcgen05/TMEM instructions, so SM120 /
SM121 (GB10) is served by the separate `fmha_sm12x` package rather than
aliasing the SM100 path:

![MSA architecture](docs/architecture.png)

Expand All @@ -25,16 +27,16 @@ share one Python package:

## Requirements

- **GPU**: NVIDIA SM100.
- **GPU**: NVIDIA SM100 / SM103 for the `fmha_sm100` kernels; SM120 / SM121 (GB10) is served by the `fmha_sm12x` package (portable CUDA helpers + a Triton block-sparse prefill kernel, with Torch reference fallbacks).
- **Toolchain**: CUDA Toolkit with `nvcc` on `PATH` (or `CUDA_HOME` / `CUDA_PATH` set).
- **Python**: ≥ 3.10.
- **OS**: Linux x86_64 (aarch64 untested; JIT builds may need small Makefile edits on WSL).

Quick sanity check before installing:

```bash
nvcc --version # expect ≥ 12.x
nvidia-smi --query-gpu=compute_cap --format=csv | grep "10.0" # confirm SM100
nvcc --version # expect ≥ 12.x for SM100, CUDA 13.x for SM120/SM121
nvidia-smi --query-gpu=compute_cap --format=csv | grep -E "10\.(0|3)|12\.(0|1)"
python -c "import sys; print(sys.version_info[:2])" # ≥ (3, 10)
```

Expand All @@ -56,6 +58,26 @@ This pulls in the CuTe-DSL stack via `nvidia-cutlass-dsl` and `quack-kernels`;
the csrc kernels are JIT-compiled at first import from sources shipped inside
the package.

By default, csrc JIT builds preserve the original SM100/SM103 targets. The
`fmha_sm12x` package is a parallel SM120/SM121 namespace: it ships real
portable CUDA helpers (the k2q CSR builder, paged-decode split-KV scheduler,
and top-k selector), a semi-optimized Triton kernel for block-sparse prefill
attention (BF16/FP16, plus FP8 E4M3 and NVFP4 K/V staged to BF16), and
correctness-first Torch references for the rest (dense attention, FP4 block
scoring, paged decode). The
Triton path is optional — Triton ships transitively with `torch` on Linux, is
imported lazily, and falls back to the Torch reference when unavailable, so it
is not a declared dependency. Do not compile the existing `fmha_sm100` kernels
for GB10 / SM121; they rely on SM100-only tcgen05/TMEM operations. For SM12x
development, set the target before importing compiled SM12x kernel modules so
caches are partitioned by architecture:

```bash
export MSA_CUDA_ARCH=sm_121
# For custom multi-target builds, override the full nvcc gencode list:
# export MSA_NVCC_GENCODES='-gencode=arch=compute_120,code=sm_120 -gencode=arch=compute_121,code=sm_121'
```

## Verify

Run a small CUDA smoke test. **The first run JIT-compiles `sparse_topk_select`,
Expand Down Expand Up @@ -160,11 +182,12 @@ python/fmha_sm100/ Python package
api.py fmha_sm100 / fmha_sm100_plan / sparse_topk_select
jit.py Runtime JIT (nvcc + ninja) for the csrc stack
sparse.py Lazy shim that loads the cute/ stack
sparse_fmha_adapter.py Bridge: fmha_sm100 API sparse_atten_func
sparse_fmha_adapter.py Bridge: fmha_sm100 API -> sparse_atten_func
csrc/ CUDA kernels + Jinja templates (JIT-compiled)
include/ Vendored FlashInfer / CUTLASS-derived / TRT-LLM headers
cutlass/ NVIDIA CUTLASS git submodule (include/ + tools/util/include/)
cute/ CuTe-DSL sparse attention (loaded via sys.path)
fmha_sm12x/ SM120/SM121 attention/indexer/decode namespace (Triton + Torch)
tests/ Correctness tests
smoke/ integration/ regression/
scripts/ Warmup + cache-management helpers
Expand All @@ -184,6 +207,21 @@ benchmarks/ bench_sparse_attention_ops.py
site to the sparse backend for prefill paths; useful when you already
drive the dense kernel and want a one-line swap to sparse.

### SM12x status

The SM100 tcgen05/TMEM kernels stay isolated; `fmha_sm12x` exposes independent
SM120/SM121-safe routes that mirror the `fmha_sm100` public surface. The
portable helpers are real optimized CUDA kernels: the q2k→k2q CSR builder, the
paged-decode split-KV scheduler, and the `sparse_topk_select` indexer. Block-
sparse prefill attention — including the FP8 E4M3 and NVFP4 K/V variants, which
are staged/dequantized to BF16 — runs on a semi-optimized Triton kernel for
dense KV, used by default when Triton is importable and falling back to the
Torch reference otherwise. Dense FMHA, the FP4 indexer block scores, and paged
decode (BF16, or FP8 staged to BF16) remain correctness-first Torch
references. These routes cover the full API and are validated on GB10;
matching SM100's fused-kernel throughput would still need dedicated SM12x
CUTLASS/CuTe kernels.

## Third-party licenses

`fmha_sm100` bundles, derives from, or depends on the third-party components
Expand Down
21 changes: 17 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,15 @@ dependencies = [
# which adds that directory to sys.path and re-exports the public sparse API.
[tool.setuptools]
package-dir = { "" = "python" }
# The cute/ sparse sources are loaded via sys.path.insert at runtime by
# fmha_sm100/sparse.py (not imported as a submodule), so they are shipped as
# package data.
packages = ["fmha_sm100"]

# Auto-discover the packages. fmha_sm12x's cute/src loaders are real importable
# subpackages, so the glob pulls them all in without listing each one.
# fmha_sm100 stays a single package on purpose: its cute/ + cutlass/ trees are
# loaded via sys.path at runtime and ship as package data (below), so they must
# NOT be discovered as importable packages.
[tool.setuptools.packages.find]
where = ["python"]
include = ["fmha_sm100", "fmha_sm12x*", "minimax_msa"]

[tool.setuptools.package-data]
fmha_sm100 = [
Expand All @@ -54,3 +59,11 @@ fmha_sm100 = [
"cutlass/include/**/*",
"cutlass/tools/util/include/**/*",
]

fmha_sm12x = [
"cute/**/*.py",
"cute/**/*.cu",
"cute/**/*.cuh",
"cute/**/*.h",
"cute/**/*.hpp",
]
82 changes: 82 additions & 0 deletions python/fmha_sm12x/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# SPDX-License-Identifier: MIT

"""SM120/SM121 package facade for MiniMax Sparse Attention."""

from __future__ import annotations

from minimax_msa.arch import (
CudaArch,
cpp_extension_arch_flag,
cuda_arch_cache_suffix,
nvcc_gencode_flags,
selected_cuda_arch,
)

_API_EXPORTS = frozenset(
{
"Sm12xPlan",
"fmha_sm12x",
"fmha_sm12x_plan",
"sparse_topk_select",
}
)
_SPARSE_EXPORTS = frozenset(
{
"Nvfp4QuantizedTensor",
"SparseDecodePagedAttentionWrapper",
"SparseK2qCsrBuilderSm12x",
"build_k2q_csr",
"dequantize_nvfp4_128x4_to_bf16",
"fp4_indexer_block_scores",
"nvfp4_global_scale_from_amax",
"nvfp4_scale_128x4_offset",
"quantize_bf16_to_nvfp4_128x4",
"quantize_kv_bf16_to_nvfp4_128x4",
"sparse_atten_func",
"sparse_atten_nvfp4_kv_func",
"sparse_decode_atten_func",
"swizzle_nvfp4_scale_to_128x4",
}
)

__all__ = [
"CudaArch",
"Nvfp4QuantizedTensor",
"Sm12xPlan",
"SparseDecodePagedAttentionWrapper",
"SparseK2qCsrBuilderSm12x",
"build_k2q_csr",
"cpp_extension_arch_flag",
"cuda_arch_cache_suffix",
"dequantize_nvfp4_128x4_to_bf16",
"fmha_sm12x",
"fmha_sm12x_plan",
"fp4_indexer_block_scores",
"nvcc_gencode_flags",
"nvfp4_global_scale_from_amax",
"nvfp4_scale_128x4_offset",
"quantize_bf16_to_nvfp4_128x4",
"quantize_kv_bf16_to_nvfp4_128x4",
"selected_cuda_arch",
"sparse_atten_func",
"sparse_atten_nvfp4_kv_func",
"sparse_decode_atten_func",
"sparse_topk_select",
"swizzle_nvfp4_scale_to_128x4",
]


def __getattr__(name: str):
if name in _API_EXPORTS:
from . import api as _api

return getattr(_api, name)
if name in _SPARSE_EXPORTS:
from . import sparse as _sparse

return getattr(_sparse, name)
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")


def __dir__() -> list[str]:
return sorted({*globals(), *__all__})
131 changes: 131 additions & 0 deletions python/fmha_sm12x/_decode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# SPDX-License-Identifier: MIT

"""SM12x paged decode reference implementation."""

from __future__ import annotations

import torch

from ._lse import run_lse
from .api import fmha_sm12x, fmha_sm12x_plan


def _compact_page_table(page_table: torch.Tensor, seqused_k: torch.Tensor, page_size: int) -> torch.Tensor:
pages: list[torch.Tensor] = []
lengths = seqused_k.to("cpu", dtype=torch.int64, non_blocking=False).tolist()
for batch, kv_len in enumerate(lengths):
page_count = (int(kv_len) + int(page_size) - 1) // int(page_size)
if page_count > 0:
pages.append(page_table[batch, :page_count])
if not pages:
return torch.empty((0,), dtype=torch.int32, device=page_table.device)
return torch.cat(pages).contiguous()


def sparse_decode_atten_func(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
q2k_indices: torch.Tensor | None = None,
*,
page_table: torch.Tensor,
seqused_k: torch.Tensor,
seqlen_q: int,
max_seqlen_k: int,
blk_kv: int = 128,
causal: bool = True,
softmax_scale: float | None = None,
return_softmax_lse: bool = False,
**_kwargs,
):
"""Run SM12x paged decode via the dequantized Torch reference path."""

if q.ndim != 3 or k.ndim != 4 or v.ndim != 4:
raise ValueError("decode expects q [B*S,Hq,D] and paged k/v [P,Hkv,page,D]")
# FP8 E4M3 K/V cache (and FP8 Q) is staged to BF16, matching the SM100
# decode path; the reference then runs in BF16.
q = q.to(torch.bfloat16) if q.dtype == torch.float8_e4m3fn else q
k = k.to(torch.bfloat16) if k.dtype == torch.float8_e4m3fn else k
v = v.to(torch.bfloat16) if v.dtype == torch.float8_e4m3fn else v
if page_table.dtype != torch.int32 or seqused_k.dtype != torch.int32:
raise TypeError("page_table and seqused_k must be int32")
if page_table.device != q.device or seqused_k.device != q.device:
raise ValueError("decode metadata must be on q.device")
batch = int(page_table.shape[0])
if int(q.shape[0]) != batch * int(seqlen_q):
raise ValueError("q.shape[0] must equal batch * seqlen_q")
if int(k.shape[2]) != int(blk_kv) or k.shape != v.shape:
raise ValueError("paged k/v shapes must match and use blk_kv page size")
if max_seqlen_k <= 0:
raise ValueError("max_seqlen_k must be positive")
qo_lens = torch.full((batch,), int(seqlen_q), dtype=torch.int32, device=q.device)
kv_lens = seqused_k.to(dtype=torch.int32)
kv_indices = _compact_page_table(page_table.contiguous(), seqused_k.contiguous(), int(blk_kv))
plan = fmha_sm12x_plan(
qo_lens, kv_lens, int(q.shape[1]), int(k.shape[1]), page_size=int(blk_kv),
causal=bool(causal),
)
block_indexes = None
if q2k_indices is not None:
if q2k_indices.dtype != torch.int32 or q2k_indices.ndim != 3:
raise ValueError("q2k_indices must be int32 with shape [Hkv, total_q, topK]")
block_indexes = q2k_indices.permute(1, 0, 2).contiguous()
out, _ = fmha_sm12x(
q, k, v, plan, kv_indices=kv_indices, kv_block_indexes=block_indexes,
sm_scale=softmax_scale,
)
if return_softmax_lse:
lse = run_lse(
q, k, v, plan, kv_indices=kv_indices, kv_block_indexes=block_indexes,
sm_scale=softmax_scale,
)
return out, lse
return out


class SparseDecodePagedAttentionWrapper:
"""Plan/run wrapper matching the SM100 decode surface for SM12x."""

def __init__(self, *, blk_kv: int = 128, causal: bool = True) -> None:
self.blk_kv = int(blk_kv)
self.causal = bool(causal)
self.page_table: torch.Tensor | None = None
self.seqused_k: torch.Tensor | None = None
self.q2k_indices: torch.Tensor | None = None
self.seqlen_q: int | None = None
self.max_seqlen_k: int | None = None

def plan(
self,
*,
page_table: torch.Tensor,
seqused_k: torch.Tensor,
seqlen_q: int,
max_seqlen_k: int,
q2k_indices: torch.Tensor | None = None,
**_kwargs,
) -> "SparseDecodePagedAttentionWrapper":
self.page_table = page_table.contiguous()
self.seqused_k = seqused_k.contiguous()
self.q2k_indices = None if q2k_indices is None else q2k_indices.contiguous()
self.seqlen_q = int(seqlen_q)
self.max_seqlen_k = int(max_seqlen_k)
return self

def run(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
*,
softmax_scale: float | None = None,
return_softmax_lse: bool = False,
**_kwargs,
):
if self.page_table is None or self.seqused_k is None or self.seqlen_q is None or self.max_seqlen_k is None:
raise RuntimeError("decode wrapper must be planned before run")
return sparse_decode_atten_func(
q, k, v, self.q2k_indices, page_table=self.page_table, seqused_k=self.seqused_k,
seqlen_q=self.seqlen_q, max_seqlen_k=self.max_seqlen_k, blk_kv=self.blk_kv,
causal=self.causal, softmax_scale=softmax_scale, return_softmax_lse=return_softmax_lse,
)
Loading