From 3e3a71243b2e65bdbd9f14d59bb6e4258a79fd36 Mon Sep 17 00:00:00 2001 From: Jeongkyu Shin Date: Tue, 30 Jun 2026 15:23:45 +0900 Subject: [PATCH] feat: dequantize MLX 4/8-bit checkpoints in the loader (#449 M3 Stage 2d) The OpenXLA loader read only bf16/f16/f32, so the many MLX-quantized checkpoints (stored as U32-packed weights with companion `*.scales` / `*.biases`) could not run on the XLA path, even for the already-supported Llama / Qwen2 architectures. Add affine dequantization at load. - `Config` gains `quantization: Option`, read from `config.json`'s `quantization` block. The emitted graph is unchanged (it runs in f32); only the loader differs. - `weights::dequantize_affine` unpacks the U32 weight (low-order-first, `32/bits` values per word) and applies `q * scale + bias` per group of `group_size` input columns, the MLX affine layout. - `load_weights` dequantizes any `U32` weight using its same-shard `*.scales` / `*.biases` (f16), producing the `[out, in]` f32 the graph expects; layernorms and q/k/v biases are f16 and fall through to the widen path. 4-bit and 8-bit are supported. Validation: - The formula is verified against real data without mlx: dequantizing Qwen2.5-0.5B-4bit and correlating with the original bf16 weights gives 0.995+ across embed / attention / MLP (a wrong unpacking would give ~0). - Unit test: a hand-built u32 row dequantizes to its exact expected f32. - E2E, CUDA on GB10: XLA running Qwen2.5-0.5B-4bit (the loader's dequant) is token-exact (40/40) with an HF fp32 oracle on the same weights dequantized offline. Llama-3.2-1B-4bit (a different arch, tied) generates coherent text, confirming the dequant is architecture-independent. Dequantizing to f32 trades the on-disk storage saving for running the quantized checkpoint on the f32 graph; fused quantized matmul is a later optimization (Stage 3). --- src/lib/mlxcel-xla/src/emitter/config.rs | 34 ++++++++ src/lib/mlxcel-xla/src/emitter/mod.rs | 1 + src/lib/mlxcel-xla/src/iree.rs | 66 +++++++++++++-- src/lib/mlxcel-xla/src/weights.rs | 102 +++++++++++++++++++++++ 4 files changed, 196 insertions(+), 7 deletions(-) diff --git a/src/lib/mlxcel-xla/src/emitter/config.rs b/src/lib/mlxcel-xla/src/emitter/config.rs index 2679bc13..3577982d 100644 --- a/src/lib/mlxcel-xla/src/emitter/config.rs +++ b/src/lib/mlxcel-xla/src/emitter/config.rs @@ -25,6 +25,16 @@ pub enum RopeScaling { }, } +/// MLX affine weight quantization (`config.json` `quantization`). The linear / +/// embedding `*.weight` tensors are stored packed as `U32` with companion +/// `*.scales` / `*.biases`; the loader dequantizes them to f32 as +/// `q * scale + bias` per group of `group_size` input columns. +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct QuantConfig { + pub bits: usize, + pub group_size: usize, +} + #[derive(Clone, Debug)] pub struct Config { pub hidden: usize, @@ -46,6 +56,10 @@ pub struct Config { /// projection; `false` adds a separate `params['lm_head']` weight the tail /// projects through instead (Llama-3.1-8B, larger Qwen2.5 sizes). pub tie_word_embeddings: bool, + /// MLX affine weight quantization, if the checkpoint is quantized (`None` for + /// an unquantized bf16/f16/f32 checkpoint). The graph itself is unchanged (it + /// runs in f32); the loader dequantizes the packed weights at load. + pub quantization: Option, } impl Config { @@ -69,6 +83,7 @@ impl Config { }, qkv_bias: false, tie_word_embeddings: true, + quantization: None, } } @@ -120,6 +135,24 @@ impl Config { .and_then(serde_json::Value::as_bool) .unwrap_or(true); + // MLX affine quantization: an optional `{bits, group_size}` block. The + // loader dequantizes the packed weights; the emitted graph is unchanged. + let quantization = match v.get("quantization") { + None | Some(serde_json::Value::Null) => None, + Some(q) => { + let qu = |k: &str| -> Result { + q.get(k) + .and_then(serde_json::Value::as_u64) + .map(|x| x as usize) + .ok_or_else(|| format!("config.json quantization missing integer `{k}`")) + }; + Some(QuantConfig { + bits: qu("bits")?, + group_size: qu("group_size")?, + }) + } + }; + let u = |k: &str| -> Result { v.get(k) .and_then(serde_json::Value::as_u64) @@ -193,6 +226,7 @@ impl Config { rope, qkv_bias, tie_word_embeddings, + quantization, }) } diff --git a/src/lib/mlxcel-xla/src/emitter/mod.rs b/src/lib/mlxcel-xla/src/emitter/mod.rs index 21ea69b7..648f6663 100644 --- a/src/lib/mlxcel-xla/src/emitter/mod.rs +++ b/src/lib/mlxcel-xla/src/emitter/mod.rs @@ -83,6 +83,7 @@ mod tests { rope: RopeScaling::Plain, qkv_bias, tie_word_embeddings: true, + quantization: None, } } diff --git a/src/lib/mlxcel-xla/src/iree.rs b/src/lib/mlxcel-xla/src/iree.rs index 0e8ed8e1..fb4361de 100644 --- a/src/lib/mlxcel-xla/src/iree.rs +++ b/src/lib/mlxcel-xla/src/iree.rs @@ -20,10 +20,11 @@ //! [`Config`], (2) emits the `prefill` / `decode_step` StableHLO graphs from that //! config (the #451 emitter, ending in an on-device argmax) and compiles them to //! vmfbs with the dist's `iree-compile`, cached by content hash, (3) loads the -//! weights as f32 (widening bf16 / f16, or copying f32) in the emitter's arg -//! order, from either a single-file `model.safetensors` or a sharded checkpoint -//! (via its `model.safetensors.index.json`, which is how the big untied models -//! ship), and (4) hands all of it to the +//! weights as f32 (widening bf16 / f16, copying f32, or dequantizing MLX 4 / 8-bit +//! affine weights) in the emitter's arg order, from either a single-file +//! `model.safetensors` or a sharded checkpoint (via its +//! `model.safetensors.index.json`, which is how the big untied models ship), and +//! (4) hands all of it to the //! shim, which keeps the weights resident on the device and threads the KV cache //! across steps. Then [`IreeLlama::prefill`] / [`IreeLlama::decode`] are token-in //! / token-out. Emitting from config (issue #449 M3 Stage 2d) replaced the bundled @@ -50,7 +51,7 @@ use memmap2::Mmap; use safetensors::{Dtype, SafeTensors}; use crate::emitter::{Config, emit_decode, emit_decode_ragged, emit_prefill}; -use crate::weights::{bf16_to_f32, f16_to_f32, f32_le_to_f32}; +use crate::weights::{bf16_to_f32, dequantize_affine, f16_to_f32, f32_le_to_f32}; /// Prefill bucket baked into the emitted `prefill` graph (`tensor<256xi32>`, /// == MAX_SEQ, so it covers any prompt the 256-slot KV cache holds). @@ -385,6 +386,57 @@ fn load_weights( let t = st .tensor(name) .map_err(|e| format!("weight {name} in {}: {e}", shard.display()))?; + + // MLX affine-quantized weight: a `U32`-packed `[out, in_packed]` weight + // with same-shard `*.scales` / `*.biases`. Dequantize to `[out, in]` f32 + // (the graph's dtype); the layernorms and q/k/v biases are not quantized + // and fall through to the widen path below. + if t.dtype() == Dtype::U32 { + let qc = cfg.quantization.ok_or_else(|| { + format!( + "weight {name} is U32 (quantized) but config.json has no `quantization`" + ) + })?; + let prefix = name + .strip_suffix(".weight") + .ok_or_else(|| format!("quantized weight {name} does not end in `.weight`"))?; + let scales_name = format!("{prefix}.scales"); + let biases_name = format!("{prefix}.biases"); + let scales = st.tensor(&scales_name).map_err(|e| { + format!("{scales_name} (for {name}) in {}: {e}", shard.display()) + })?; + let biases = st.tensor(&biases_name).map_err(|e| { + format!("{biases_name} (for {name}) in {}: {e}", shard.display()) + })?; + if scales.dtype() != Dtype::F16 || biases.dtype() != Dtype::F16 { + return Err(format!( + "{prefix} scales/biases dtype {:?}/{:?}, expected F16", + scales.dtype(), + biases.dtype() + )); + } + let shape = t.shape(); + if shape.len() != 2 { + return Err(format!("quantized weight {name} rank {} != 2", shape.len())); + } + let (out, in_packed) = (shape[0], shape[1]); + let in_ = in_packed * (32 / qc.bits); + bufs[i] = dequantize_affine( + t.data(), + scales.data(), + biases.data(), + out, + in_packed, + qc.bits, + qc.group_size, + ) + .map_err(|e| format!("dequantize {name}: {e}"))?; + ranks[i] = 2; + dims[i * 4] = out as i64; + dims[i * 4 + 1] = in_ as i64; + continue; + } + // Widen to f32 (the graph's weight dtype). bf16 and f16 are the common // checkpoint dtypes; f32 is a passthrough. The widening is exact for // all three, so it matches HF's f32 reference. @@ -394,8 +446,8 @@ fn load_weights( Dtype::F32 => f32_le_to_f32(t.data()), other => { return Err(format!( - "weight {name} dtype {other:?}, expected BF16/F16/F32 \ - (a quantized checkpoint must be dequantized first)" + "weight {name} dtype {other:?}, expected BF16/F16/F32 or \ + MLX-quantized U32" )); } }; diff --git a/src/lib/mlxcel-xla/src/weights.rs b/src/lib/mlxcel-xla/src/weights.rs index eb34e2ce..d3e8b72f 100644 --- a/src/lib/mlxcel-xla/src/weights.rs +++ b/src/lib/mlxcel-xla/src/weights.rs @@ -62,6 +62,74 @@ pub(crate) fn f32_le_to_f32(bytes: &[u8]) -> Vec { .collect() } +/// Dequantize one MLX affine-quantized weight to row-major `[out, in]` f32. +/// +/// `packed` is the row-major `[out, in_packed]` u32 weight (little-endian bytes, +/// `in_packed = in * bits / 32`); `scales` / `biases` are the row-major +/// `[out, in/group_size]` f16 buffers. Each weight is recovered as +/// `w[o,i] = q[o,i] * scale[o, i/group_size] + bias[o, i/group_size]`, where `q` +/// is the `bits`-wide value unpacked low-order-first from `packed[o, i/(32/bits)]` +/// (the MLX affine layout). The graph runs in f32, so the packed weights are +/// widened here once at load. +pub(crate) fn dequantize_affine( + packed: &[u8], + scales: &[u8], + biases: &[u8], + out: usize, + in_packed: usize, + bits: usize, + group_size: usize, +) -> Result, String> { + if !(bits == 4 || bits == 8) { + return Err(format!( + "unsupported quantization bits {bits} (expected 4 or 8)" + )); + } + let per_u32 = 32 / bits; // values packed per u32 + let in_ = in_packed * per_u32; + if group_size == 0 || !in_.is_multiple_of(group_size) { + return Err(format!( + "quantization group_size {group_size} does not divide in dimension {in_}" + )); + } + let n_groups = in_ / group_size; + if packed.len() != out * in_packed * 4 { + return Err(format!( + "packed weight is {} bytes, expected {} ([{out}, {in_packed}] u32)", + packed.len(), + out * in_packed * 4 + )); + } + let scales = f16_to_f32(scales); + let biases = f16_to_f32(biases); + if scales.len() != out * n_groups || biases.len() != out * n_groups { + return Err(format!( + "scales/biases have {}/{} elements, expected {} ([{out}, {n_groups}])", + scales.len(), + biases.len(), + out * n_groups + )); + } + let mask: u32 = (1u32 << bits) - 1; + let mut w = vec![0f32; out * in_]; + for o in 0..out { + let row = &packed[o * in_packed * 4..(o + 1) * in_packed * 4]; + let grow = o * n_groups; + let wrow = o * in_; + for p in 0..in_packed { + let u = + u32::from_le_bytes([row[p * 4], row[p * 4 + 1], row[p * 4 + 2], row[p * 4 + 3]]); + for j in 0..per_u32 { + let i = p * per_u32 + j; + let q = ((u >> (bits * j)) & mask) as f32; + let g = i / group_size; + w[wrow + i] = q * scales[grow + g] + biases[grow + g]; + } + } + } + Ok(w) +} + #[cfg(test)] mod tests { use super::*; @@ -115,4 +183,38 @@ mod tests { let bytes = 1.5f32.to_le_bytes(); assert_eq!(f32_le_to_f32(&bytes), vec![1.5]); } + + /// 4-bit affine dequant on a hand-built row: one u32 packs eight nibbles + /// 1..=8 (low-order first), two groups of 4 with scale/bias (2.0, +10) and + /// (0.5, -1), so `q*scale + bias` is exact. + #[test] + fn dequantize_affine_recovers_hand_example() { + // u32 = 0x8765_4321 -> nibbles [1,2,3,4,5,6,7,8] low-order first. + let packed = [0x21u8, 0x43, 0x65, 0x87]; + let scales = [0x00u8, 0x40, 0x00, 0x38]; // f16 [2.0, 0.5] + let biases = [0x00u8, 0x49, 0x00, 0xBC]; // f16 [10.0, -1.0] + let w = dequantize_affine(&packed, &scales, &biases, 1, 1, 4, 4).unwrap(); + assert_eq!(w, vec![12.0, 14.0, 16.0, 18.0, 1.5, 2.0, 2.5, 3.0]); + } + + /// 8-bit affine dequant: one u32 packs four bytes 10/20/30/40 (low-order + /// first), two groups of 2 with scale/bias (2.0, +10) and (0.5, -1), so + /// `q*scale + bias` is exact. Exercises the `bits = 8` (`per_u32 = 4`) path. + #[test] + fn dequantize_affine_8bit_recovers_hand_example() { + // u32 = 0x281E_140A -> bytes [10, 20, 30, 40] low-order first. + let packed = [0x0Au8, 0x14, 0x1E, 0x28]; + let scales = [0x00u8, 0x40, 0x00, 0x38]; // f16 [2.0, 0.5] + let biases = [0x00u8, 0x49, 0x00, 0xBC]; // f16 [10.0, -1.0] + let w = dequantize_affine(&packed, &scales, &biases, 1, 1, 8, 2).unwrap(); + assert_eq!(w, vec![30.0, 50.0, 14.0, 19.0]); + } + + /// A packed buffer whose size disagrees with `[out, in_packed]` is rejected. + #[test] + fn dequantize_affine_rejects_size_mismatch() { + let packed = [0u8; 4]; + let sb = [0u8; 4]; + assert!(dequantize_affine(&packed, &sb, &sb, 2, 1, 4, 4).is_err()); + } }