feat: dequantize MLX 4/8-bit checkpoints in the loader (#449 M3 Stage 2d) #489
Closed
inureyes wants to merge 1 commit into
… 2d)
The OpenXLA loader read only bf16/f16/f32, so the many MLX-quantized
checkpoints (stored as U32-packed weights with companion `*.scales` /
`*.biases`) could not run on the XLA path, even for the already-supported
Llama / Qwen2 architectures. Add affine dequantization at load.
- `Config` gains `quantization: Option<QuantConfig{bits, group_size}>`,
read from `config.json`'s `quantization` block. The emitted graph is
unchanged (it runs in f32); only the loader differs.
- `weights::dequantize_affine` unpacks the U32 weight (low-order-first,
`32/bits` values per word) and applies `q * scale + bias` per group of
`group_size` input columns, the MLX affine layout.
- `load_weights` dequantizes any `U32` weight using its same-shard
`*.scales` / `*.biases` (f16), producing the `[out, in]` f32 the graph
expects; layernorms and q/k/v biases are f16 and fall through to the
widen path. 4-bit and 8-bit are supported.
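The unpack-and-scale step described above can be sketched as follows. This is a minimal illustration, not the PR's actual `weights::dequantize_affine`: the function name `dequantize_affine_sketch`, its signature, and the assumption that scales and biases arrive already widened to f32 with one value per (row, group) are all hypothetical.

```rust
/// Hypothetical sketch of MLX-style affine dequantization for one weight matrix.
/// Assumptions (not taken from the PR's code):
/// - `packed` is row-major, `in_cols * bits / 32` u32 words per output row;
/// - `scales` / `biases` hold one f32 per (row, group), already widened from f16.
pub fn dequantize_affine_sketch(
    packed: &[u32],
    scales: &[f32],
    biases: &[f32],
    out_rows: usize,
    in_cols: usize,
    bits: u32,
    group_size: usize,
) -> Result<Vec<f32>, String> {
    if bits != 4 && bits != 8 {
        return Err(format!("unsupported bit width: {}", bits));
    }
    // Number of quantized values stored in each u32 word (8 for 4-bit, 4 for 8-bit).
    let per_word = (32 / bits) as usize;
    if group_size == 0 || in_cols % group_size != 0 || in_cols % per_word != 0 {
        return Err("in_cols must be a multiple of group_size and of 32/bits".to_string());
    }
    let words_per_row = in_cols / per_word;
    let groups_per_row = in_cols / group_size;
    if packed.len() != out_rows * words_per_row
        || scales.len() != out_rows * groups_per_row
        || biases.len() != scales.len()
    {
        return Err("buffer lengths do not match the declared shape".to_string());
    }

    let mask = (1u32 << bits) - 1;
    let mut out = Vec::with_capacity(out_rows * in_cols);
    for row in 0..out_rows {
        for col in 0..in_cols {
            // Low-order-first: value 0 of a word sits in its lowest `bits` bits.
            let word = packed[row * words_per_row + col / per_word];
            let shift = (col % per_word) as u32 * bits;
            let q = ((word >> shift) & mask) as f32;
            // One scale/bias pair per group of `group_size` input columns.
            let g = row * groups_per_row + col / group_size;
            out.push(q * scales[g] + biases[g]);
        }
    }
    Ok(out) // row-major [out, in] f32
}
```

For example, with `bits = 4` and `group_size = 4`, the single word `0x7654_3210` unpacks to the values 0 through 7; with scales `[0.5, 2.0]` and biases `[1.0, -1.0]` the row dequantizes to `[1.0, 1.5, 2.0, 2.5, 7.0, 9.0, 11.0, 13.0]`.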
Validation:
- The formula is verified against real data without mlx: dequantizing
Qwen2.5-0.5B-4bit and correlating with the original bf16 weights gives a
correlation of 0.995+ across embed / attention / MLP (a wrong unpacking
would give ~0).
- Unit test: a hand-built u32 row dequantizes to its exact expected f32.
- E2E, CUDA on GB10: XLA running Qwen2.5-0.5B-4bit (the loader's dequant)
is token-exact (40/40) with an HF fp32 oracle on the same weights
dequantized offline. Llama-3.2-1B-4bit (a different arch, tied) generates
coherent text, confirming the dequant is architecture-independent.
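The correlation check in the first validation item could look roughly like the sketch below. The PR does not say which correlation measure was used; Pearson correlation is assumed here, and the helper name `pearson` is hypothetical.

```rust
/// Pearson correlation between two equal-length weight buffers, e.g. a
/// dequantized tensor and the original bf16 tensor widened to f32.
/// Returns None for empty or mismatched inputs, or when either side has
/// zero variance. (Hypothetical helper; the measure is an assumption.)
pub fn pearson(a: &[f32], b: &[f32]) -> Option<f64> {
    if a.len() != b.len() || a.is_empty() {
        return None;
    }
    let n = a.len() as f64;
    let mean_a = a.iter().map(|&x| x as f64).sum::<f64>() / n;
    let mean_b = b.iter().map(|&x| x as f64).sum::<f64>() / n;

    // Accumulate in f64 to limit rounding error over large tensors.
    let (mut cov, mut var_a, mut var_b) = (0.0f64, 0.0f64, 0.0f64);
    for (&x, &y) in a.iter().zip(b) {
        let dx = x as f64 - mean_a;
        let dy = y as f64 - mean_b;
        cov += dx * dy;
        var_a += dx * dx;
        var_b += dy * dy;
    }
    if var_a == 0.0 || var_b == 0.0 {
        return None;
    }
    Some(cov / (var_a * var_b).sqrt())
}
```

A correct unpacking should give a value near 1 against the reference weights, while a scrambled nibble order decorrelates the two buffers, which is what makes this a useful check without mlx available.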
Dequantizing to f32 gives up the quantized format's storage saving in
exchange for running the checkpoint on the unchanged f32 graph; fused
quantized matmul is a later optimization (Stage 3).
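To put a rough number on that trade-off, the sketch below compares the packed size of one weight matrix with its dequantized f32 size. The helper names are hypothetical, and the 1024 x 1024 shape and `group_size = 64` in the example are illustrative assumptions, not values from the PR.

```rust
/// Bytes for one MLX-style quantized matrix: packed weights plus one f16
/// scale and one f16 bias per group. (Hypothetical helper for illustration.)
pub fn packed_bytes(out_rows: usize, in_cols: usize, bits: usize, group_size: usize) -> usize {
    let weight_bytes = out_rows * in_cols * bits / 8;
    let groups = out_rows * (in_cols / group_size);
    // 2 bytes per f16 value, two values (scale and bias) per group.
    weight_bytes + groups * 2 * 2
}

/// Bytes for the same matrix after dequantizing to f32.
pub fn f32_bytes(out_rows: usize, in_cols: usize) -> usize {
    out_rows * in_cols * 4
}
```

For a hypothetical 1024 x 1024 matrix at 4 bits with groups of 64, `packed_bytes` gives 589,824 bytes against 4,194,304 bytes for f32, roughly a 7x increase once loaded.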
ed8ed24 to ee3036d
Summary
The OpenXLA loader read only bf16/f16/f32, so the many MLX-quantized checkpoints (U32-packed weights plus companion `*.scales` / `*.biases`) couldn't run on the XLA path, even for the already-supported Llama / Qwen2 arches. This adds affine dequantization at load (4-bit and 8-bit), so quantized checkpoints run on those arches. Part of #449 M3 Stage 2d.
What changed (`mlxcel-xla`)
- `Config` gains `quantization: Option<QuantConfig{bits, group_size}>` from `config.json`. The emitted graph is unchanged (it runs in f32); only the loader differs.
- `weights::dequantize_affine` unpacks the U32 weight (low-order-first, `32/bits` values per word) and applies `q * scale + bias` per group of `group_size` input columns (the MLX affine layout).
- `load_weights` dequantizes any `U32` weight using its same-shard `*.scales` / `*.biases` (f16), producing the `[out, in]` f32 the graph expects; layernorms and q/k/v biases are not quantized (they are stored as f16) and fall through to the widen path.
Validation
Notes
Refs #449.