Add SM12x MSA Support (Tested on DGX Spark, linux/arm64/SM121) #1

Open
dbotwinick wants to merge 5 commits into MiniMax-AI:main from dbotwinick:sm12x-parity

dbotwinick commented Jun 11, 2026


This pull request introduces support for NVIDIA SM120/SM121 architectures (tested on SM121, GB10) by adding a new fmha_sm12x package alongside the existing SM100 code. The documentation is updated to clarify architecture support and package separation, and the Python packaging is adjusted to include the new subpackage and its data. Two new reference implementations for SM12x (paged decode and FP4 indexer) are also added.

Note: Triton is the higher-performing path (though not fully optimized), with a pure Torch fallback.

SM120/SM121 (GB10) support and new package:

  • Added the new fmha_sm12x Python package as a parallel namespace for SM120/SM121, exposing a public API facade and lazy attribute loading for its components. This package includes support for dense and sparse attention, paged decode, and quantization/dequantization utilities, mirroring the fmha_sm100 API where possible.
  • Implemented a reference SM12x paged decode kernel and wrapper (_decode.py), providing a correctness-first Torch-based path for block-sparse decode matching the SM100 surface.
  • Added a reference SM12x FP4 indexer (_fp4.py) for dequantizing FP4 tensors and computing per-page QK max scores, supporting both MXFP4 and NVFP4 formats.
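To make the indexer's two steps concrete, here is a minimal pure-Python sketch of FP4 (E2M1) dequantization followed by a per-page QK max. It is not the code in _fp4.py. The function names, the nibble order (low nibble first), and the reading of "per-page QK max" as the largest q·k over the keys in each page are assumptions made for illustration.

```python
# E2M1 magnitudes indexed by the low three bits of a 4-bit code; bit 3 is the sign.
E2M1_VALUES = (0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0)


def decode_e2m1(code):
    """Decode one 4-bit E2M1 code into a float."""
    magnitude = E2M1_VALUES[code & 0x7]
    return -magnitude if code & 0x8 else magnitude


def dequant_block(packed, scale):
    """Dequantize packed FP4 bytes that share one scale.

    Assumption: the low nibble holds the earlier element. For MXFP4 the scale
    would be a power of two shared by 32 elements; for NVFP4 it would be an
    FP8 E4M3 value shared by 16 elements (combined with a per-tensor scale).
    """
    out = []
    for byte in packed:
        out.append(decode_e2m1(byte & 0xF) * scale)
        out.append(decode_e2m1(byte >> 4) * scale)
    return out


def page_max_scores(q, keys, page_size):
    """Return the largest q . k over the keys of each page (assumed definition)."""
    scores = [sum(a * b for a, b in zip(q, k)) for k in keys]
    return [max(scores[i:i + page_size]) for i in range(0, len(scores), page_size)]
```

The two formats differ only in how `scale` is obtained per block, which is why a single dequantization routine can plausibly serve both.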

Documentation updates:

  • Updated README.md to clarify the separation between SM100/SM103 and SM120/SM121 support, explain the purpose of the new fmha_sm12x package, and describe the architecture-specific kernels, helper utilities, and fallback mechanisms. Also added a new section on the SM12x status and usage.

Packaging and distribution:

  • Modified pyproject.toml to auto-discover and include the new fmha_sm12x package and its data files, ensuring all necessary CUDA and Python sources are shipped for SM12x support.

Resolve the nvcc gencode list, the cpp_extension arch flag, and a per-arch JIT cache suffix from the MSA_CUDA_ARCH / MSA_NVCC_GENCODES env overrides or the detected device (SM100/SM103/SM120/SM121). Provide require_sm100_csrc_arch / require_sm12x_csrc_arch guards so each family's csrc kernels build only for their supported targets.
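The resolution order described above (env override first, detected device second) can be sketched roughly as follows. This is an illustration, not the helper added in this PR. Only the MSA_CUDA_ARCH variable name and the four SM numbers come from the description; the function names, the accepted override spellings ("121" or "sm_121"), the flag layout, and the suffix format are hypothetical.

```python
import os

# Architecture families named in the PR description.
SM100_FAMILY = {100, 103}
SM12X_FAMILY = {120, 121}


def resolve_arch(detected=None, env=None):
    """Return the target SM number, preferring the MSA_CUDA_ARCH override.

    Assumption: the override is written like "121" or "sm_121"; the real
    helper may accept other spellings.
    """
    env = os.environ if env is None else env
    raw = env.get("MSA_CUDA_ARCH", "").strip().lower()
    if raw:
        return int(raw.removeprefix("sm_"))
    if detected is None:
        raise RuntimeError("no MSA_CUDA_ARCH override and no detected device")
    major, minor = detected  # e.g. the tuple a CUDA capability query returns
    return major * 10 + minor


def gencode_flag(arch):
    """Build one nvcc -gencode flag for the given target (hypothetical layout)."""
    return f"-gencode=arch=compute_{arch},code=sm_{arch}"


def cache_suffix(arch):
    """Per-arch suffix so JIT build directories for different targets do not collide."""
    return f"_sm{arch}"


def require_family(arch, family, name):
    """Hypothetical guard: refuse to build one family's kernels for another target."""
    if arch not in family:
        raise RuntimeError(f"{name} kernels do not support sm_{arch}")
```

For example, `require_family(resolve_arch(detected=(12, 1), env={}), SM12X_FAMILY, "SM12x")` passes silently, while the same call with `SM100_FAMILY` raises `RuntimeError`.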
Add fmha_sm12x as a parallel namespace to fmha_sm100 for GB10 (SM120/SM121), which lacks the SM100 tcgen05/TMEM instructions. The package ships real portable CUDA helpers: a q2k->k2q CSR builder, a paged-decode split-KV scheduler, and a sparse_topk loader that compiles fmha_sm100's shared kernel source for the SM12x target. It also ships a semi-optimized Triton block-sparse prefill kernel (BF16/FP16, with FP8 E4M3 and NVFP4 K/V staged to BF16) that falls back to a Torch reference, plus Torch reference paths for dense attention, the FP4 indexer, NVFP4 quant/dequant, and paged decode. The package mirrors the fmha_sm100 public surface; fmha_sm100 is unchanged.

Triton is imported lazily and is not a declared dependency.
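One common way to treat a package as optional is to probe for it on first use and choose a fallback when it is missing. The sketch below shows that pattern with the standard library only. It is an assumption about the general approach, not the dispatch logic in fmha_sm12x, and both function names are hypothetical.

```python
import importlib
import importlib.util
from functools import lru_cache


@lru_cache(maxsize=None)
def load_optional(module_name):
    """Import a top-level module on first use; return None when it is not installed."""
    if importlib.util.find_spec(module_name) is None:
        return None
    return importlib.import_module(module_name)


def select_prefill_path(optional_module="triton"):
    """Pick the Triton path when the module imports, else the Torch reference."""
    if load_optional(optional_module) is not None:
        return "triton"
    return "torch_reference"
```

Because the probe runs inside a function rather than at module import time, importing the package itself never fails on machines without Triton.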
Replace the explicit package list with packages.find (include = fmha_sm100, fmha_sm12x*, minimax_msa) so fmha_sm12x's importable cute/src subpackages are picked up automatically while fmha_sm100 stays a single package (its cute/ and cutlass/ trees ship as package data). Add fmha_sm12x cute sources to package-data.
Add test_arch (arch-helper unit tests), test_sm12x_reference (reference/optimized helpers), test_sm12x_triton_sparse (Triton vs Torch reference), and test_sm12x_equivalence (SM100-behaviour parity: proxy-KV E2E, onlyscore, q_offset override, paged sparse, FP4 indexer, FP8/NVFP4 K/V, decode-vs-SDPA). conftest gates by device: SM100/SM103 run the SM100 suite and skip the SM12x suite; every other device does the reverse. The forced-topk smoke test selects the package by arch.
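The device gating rule can be expressed as a small pure function that a conftest hook might call. The sketch below is illustrative only: the function names and the filename-based suite classifier are hypothetical, and the real conftest may classify tests differently (for example with markers).

```python
# Compute capabilities that belong to the SM100 family per the PR description.
SM100_CAPABILITIES = {(10, 0), (10, 3)}


def suites_to_skip(capability):
    """Return the suites to skip for a device compute capability.

    Follows the stated rule: SM100/SM103 run the SM100 suite and skip the
    SM12x suite; every other device does the reverse.
    """
    return {"sm12x"} if capability in SM100_CAPABILITIES else {"sm100"}


def should_skip(test_path, capability):
    """Hypothetical classifier: files with 'sm12x' in the name form the SM12x suite."""
    suite = "sm12x" if "sm12x" in test_path else "sm100"
    return suite in suites_to_skip(capability)
```

A pytest collection hook could call `should_skip` for each collected item and attach a skip marker when it returns true.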
Describe the companion fmha_sm12x package: portable CUDA helpers, the Triton block-sparse prefill kernel with Torch reference fallback, FP8/NVFP4 K/V staged to BF16, and Triton as an optional (non-declared) dependency.