Add SM12x MSA Support (Tested on DGX Spark, linux/arm64/SM121)#1
Open
dbotwinick wants to merge 5 commits into
Open
Add SM12x MSA Support (Tested on DGX Spark, linux/arm64/SM121)#1dbotwinick wants to merge 5 commits into
dbotwinick wants to merge 5 commits into
Conversation
Resolve the nvcc gencode list, cpp_extension arch flag, and a per-arch JIT cache suffix from MSA_CUDA_ARCH / MSA_NVCC_GENCODES env overrides or the detected device (SM100/SM103/SM120/SM121). Provides require_sm100_csrc_arch / require_sm12x_csrc_arch guards so each family's csrc kernels only build for their supported targets.
Parallel namespace to fmha_sm100 for GB10 (SM120/SM121), which lacks the SM100 tcgen05/TMEM instructions. Ships real portable CUDA helpers (q2k->k2q CSR builder, paged-decode split-KV scheduler, and a sparse_topk loader that compiles fmha_sm100's shared kernel source for the SM12x target), a semi-optimized Triton block-sparse prefill kernel (BF16/FP16, with FP8 E4M3 and NVFP4 K/V staged to BF16) that falls back to a Torch reference, and Torch reference paths for dense attention, the FP4 indexer, NVFP4 quant/dequant, and paged decode. Mirrors the fmha_sm100 public surface; fmha_sm100 is unchanged. Triton is imported lazily and is not a declared dependency.
Replace the explicit package list with packages.find (include = fmha_sm100, fmha_sm12x*, minimax_msa) so fmha_sm12x's importable cute/src subpackages are picked up automatically while fmha_sm100 stays a single package (its cute/ and cutlass/ trees ship as package data). Add fmha_sm12x cute sources to package-data.
test_arch (arch-helper unit tests), test_sm12x_reference (reference/optimized helpers), test_sm12x_triton_sparse (Triton vs Torch reference), and test_sm12x_equivalence (SM100-behaviour parity: proxy-KV E2E, onlyscore, q_offset override, paged sparse, FP4 indexer, FP8/NVFP4 K/V, decode-vs-SDPA). conftest gates by device: SM100/SM103 runs the SM100 suite and skips the SM12x suite, every other device does the reverse; the forced-topk smoke test selects the package by arch.
Describe the companion fmha_sm12x package: portable CUDA helpers, the Triton block-sparse prefill kernel with Torch reference fallback, FP8/NVFP4 K/V staged to BF16, and Triton as an optional (non-declared) dependency.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This pull request introduces support for NVIDIA SM120/SM121 architectures (tested on SM121, GB10) by adding a new
fmha_sm12xpackage alongside the existing SM100 code. The documentation is updated to clarify architecture support and package separation, and the Python packaging is adjusted to include the new subpackage and its data. Two new reference implementations for SM12x (paged decode and FP4 indexer) are also added.Note: Triton is the higher performing path (although not fully optimized) with pure torch fallback.
SM120/SM121 (GB10) support and new package:
fmha_sm12xPython package as a parallel namespace for SM120/SM121, exposing a public API facade and lazy attribute loading for its components. This package includes support for dense and sparse attention, paged decode, and quantization/dequantization utilities, mirroring thefmha_sm100API where possible._decode.py), providing a correctness-first Torch-based path for block-sparse decode matching the SM100 surface._fp4.py) for dequantizing FP4 tensors and computing per-page QK max scores, supporting both MXFP4 and NVFP4 formats.Documentation updates:
README.mdto clarify the separation between SM100/SM103 and SM120/SM121 support, explain the purpose of the newfmha_sm12xpackage, and describe the architecture-specific kernels, helper utilities, and fallback mechanisms. Also added a new section on the SM12x status and usage. [1] [2] [3] [4] [5]Packaging and distribution:
pyproject.tomlto auto-discover and include the newfmha_sm12xpackage and its data files, ensuring all necessary CUDA and Python sources are shipped for SM12x support. [1] [2]