
feat: dequantize MLX 4/8-bit checkpoints in the loader (#449 M3 Stage 2d) #490

Merged
inureyes merged 1 commit into main from feat/449-xla-4bit-dequant
Jun 30, 2026

@inureyes

Member

Summary

The OpenXLA loader read only bf16/f16/f32, so the many MLX-quantized checkpoints (U32-packed weights + companion *.scales / *.biases) couldn't run on the XLA path, even for the already-supported Llama / Qwen2 arches. This adds affine dequantization at load (4-bit and 8-bit). Part of #449 M3 Stage 2d. (Supersedes #489, which auto-closed when its stacked base branch merged; rebased onto main.)

What changed (mlxcel-xla)

  • Config gains quantization: Option<QuantConfig{bits, group_size}> from config.json. The emitted graph is unchanged (it runs in f32); only the loader differs.
  • weights::dequantize_affine unpacks the U32 weight (low-order-first, 32/bits values per word) and applies q*scale + bias per group of group_size input columns (the MLX affine layout).
  • load_weights dequantizes any U32 weight using its same-shard *.scales/*.biases (f16) → the [out, in] f32 the graph expects; layernorms and q/k/v biases are stored as f16 and fall through to the existing widen path.
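
The unpack-and-scale step described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the code merged in this PR: the function name `dequantize_affine_sketch`, its signature, and the choice to take scales and biases as already-widened `f32` slices are all hypothetical. It assumes a row-major `[out, in * bits / 32]` packed weight and row-major `[out, in / group_size]` scales and biases.

```rust
/// Hypothetical sketch of affine dequantization for one weight matrix.
/// `packed` is row-major [out_dim, in_dim * bits / 32].
/// `scales` and `biases` are row-major [out_dim, in_dim / group_size],
/// assumed here to be already widened from f16 to f32.
pub fn dequantize_affine_sketch(
    packed: &[u32],
    scales: &[f32],
    biases: &[f32],
    out_dim: usize,
    in_dim: usize,
    bits: u32,
    group_size: usize,
) -> Vec<f32> {
    assert!(bits == 4 || bits == 8, "only 4-bit and 8-bit are handled here");
    let per_word = (32 / bits) as usize; // 8 values per word at 4-bit, 4 at 8-bit
    assert!(group_size > 0 && in_dim % per_word == 0 && in_dim % group_size == 0);
    let words_per_row = in_dim / per_word;
    let groups_per_row = in_dim / group_size;
    assert_eq!(packed.len(), out_dim * words_per_row);
    assert_eq!(scales.len(), out_dim * groups_per_row);
    assert_eq!(biases.len(), out_dim * groups_per_row);

    let mask = (1u32 << bits) - 1;
    let mut out = vec![0.0f32; out_dim * in_dim];
    for row in 0..out_dim {
        for col in 0..in_dim {
            // Low-order-first: element 0 of a word sits in its lowest bits.
            let word = packed[row * words_per_row + col / per_word];
            let shift = (col % per_word) as u32 * bits;
            let q = ((word >> shift) & mask) as f32;
            // One (scale, bias) pair per group of `group_size` input columns.
            let g = row * groups_per_row + col / group_size;
            out[row * in_dim + col] = q * scales[g] + biases[g];
        }
    }
    out
}
```

For example, the single 4-bit word `0x7654_3210` unpacks to the values 0 through 7 in column order, so with a group size of 4 the first four columns use the first (scale, bias) pair and the last four use the second.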

Validation

  • Formula vs real data (no mlx needed): dequantizing Qwen2.5-0.5B-4bit and correlating the result with the original bf16 weights gives a correlation of 0.995+ across embed / attention / MLP (a wrong unpacking would give ~0).
  • Unit tests: hand-built u32 rows dequantize to their exact expected f32 for both 4-bit and 8-bit.
  • E2E, CUDA on GB10: XLA running Qwen2.5-0.5B-4bit (the loader's dequant) is token-exact (40/40) with an HF fp32 oracle on the same weights dequantized offline. Llama-3.2-1B-4bit (different arch, tied) generates coherent text → the dequant is architecture-independent.
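
A correlation check of the kind described in the first bullet could look like the sketch below. The PR does not say which correlation measure was used, so the Pearson coefficient here is an assumption, and `pearson` is a hypothetical helper name rather than a function from this repository.

```rust
/// Pearson correlation between two equally long slices (an assumed metric;
/// the PR only says "correlating"). Accumulates in f64 for stability.
/// Returns None for empty input, mismatched lengths, or zero variance.
pub fn pearson(a: &[f32], b: &[f32]) -> Option<f64> {
    if a.is_empty() || a.len() != b.len() {
        return None;
    }
    let n = a.len() as f64;
    let mean_a = a.iter().map(|&x| x as f64).sum::<f64>() / n;
    let mean_b = b.iter().map(|&y| y as f64).sum::<f64>() / n;

    let (mut cov, mut var_a, mut var_b) = (0.0f64, 0.0f64, 0.0f64);
    for (&x, &y) in a.iter().zip(b) {
        let dx = x as f64 - mean_a;
        let dy = y as f64 - mean_b;
        cov += dx * dy;
        var_a += dx * dx;
        var_b += dy * dy;
    }
    if var_a == 0.0 || var_b == 0.0 {
        return None;
    }
    Some(cov / (var_a.sqrt() * var_b.sqrt()))
}
```

Passing a flattened dequantized tensor and the matching original tensor (widened to f32) gives one score per tensor. A correct unpacking should score near 1, while a scrambled element order should score near 0, which is what makes the check sensitive to layout mistakes.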

Notes

  • Dequantizing to f32 trades the on-disk storage saving for running the quantized checkpoint on the f32 graph. Fused quantized matmul (weights packed on-device) is a later optimization (Stage 3).
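
To make the size trade-off concrete, the sketch below estimates the stored bits per weight of an affine-quantized tensor, counting one f16 scale and one f16 bias per group as described above. The helper name `quantized_bits_per_weight` is hypothetical, and the group size of 64 used in the example is an assumption, since the PR does not state the checkpoints' `group_size`.

```rust
/// Approximate stored bits per weight for affine quantization:
/// `bits` for the packed value plus an f16 scale and an f16 bias
/// (16 bits each) shared across every `group_size` weights.
pub fn quantized_bits_per_weight(bits: u32, group_size: u32) -> f64 {
    assert!(group_size > 0);
    bits as f64 + 32.0 / group_size as f64
}

/// How many times larger the same tensor is once dequantized to f32.
pub fn f32_expansion_factor(bits: u32, group_size: u32) -> f64 {
    32.0 / quantized_bits_per_weight(bits, group_size)
}
```

Under that assumed group size, a 4-bit tensor stores 4.5 bits per weight, so the f32 copy produced by the loader is roughly 7 times larger than the packed form.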

Refs #449.

… 2d)

The OpenXLA loader read only bf16/f16/f32, so the many MLX-quantized
checkpoints (stored as U32-packed weights with companion `*.scales` /
`*.biases`) could not run on the XLA path, even for the already-supported
Llama / Qwen2 architectures. Add affine dequantization at load.

- `Config` gains `quantization: Option<QuantConfig{bits, group_size}>`,
  read from `config.json`'s `quantization` block. The emitted graph is
  unchanged (it runs in f32); only the loader differs.
- `weights::dequantize_affine` unpacks the U32 weight (low-order-first,
  `32/bits` values per word) and applies `q * scale + bias` per group of
  `group_size` input columns, the MLX affine layout.
- `load_weights` dequantizes any `U32` weight using its same-shard
  `*.scales` / `*.biases` (f16), producing the `[out, in]` f32 the graph
  expects; layernorms and q/k/v biases are f16 and fall through to the
  widen path. 4-bit and 8-bit are supported.

Validation:
- The formula is verified against real data without mlx: dequantizing
  Qwen2.5-0.5B-4bit and correlating with the original bf16 weights gives
  0.995+ across embed / attention / MLP (a wrong unpacking would give ~0).
- Unit test: a hand-built u32 row dequantizes to its exact expected f32.
- E2E, CUDA on GB10: XLA running Qwen2.5-0.5B-4bit (the loader's dequant)
  is token-exact (40/40) with an HF fp32 oracle on the same weights
  dequantized offline. Llama-3.2-1B-4bit (a different arch, tied) generates
  coherent text, confirming the dequant is architecture-independent.

Dequantizing to f32 trades the on-disk storage saving for running the
quantized checkpoint on the f32 graph; fused quantized matmul is a later
optimization (Stage 3).
@inureyes added the labels area:architecture, priority:medium, status:done, type:enhancement on Jun 30, 2026
@inureyes merged commit 150cee0 into main on Jun 30, 2026
5 checks passed
@inureyes deleted the feat/449-xla-4bit-dequant branch on June 30, 2026 12:33