Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions src/lib/mlxcel-xla/src/emitter/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,16 @@ pub enum RopeScaling {
},
}

/// MLX affine weight quantization (`config.json` `quantization`). The linear /
/// embedding `*.weight` tensors are stored packed as `U32` with companion
/// `*.scales` / `*.biases`; the loader dequantizes them to f32 as
/// `q * scale + bias` per group of `group_size` input columns.
///
/// Both fields are plain integers, so the type is totally comparable and
/// hashable (`Eq` + `Hash`) in addition to `PartialEq`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct QuantConfig {
    /// Width in bits of each packed quantized value (the loader's dequantizer
    /// accepts 4 or 8).
    pub bits: usize,
    /// Number of consecutive input columns that share one scale / bias pair.
    pub group_size: usize,
}

#[derive(Clone, Debug)]
pub struct Config {
pub hidden: usize,
Expand All @@ -46,6 +56,10 @@ pub struct Config {
/// projection; `false` adds a separate `params['lm_head']` weight the tail
/// projects through instead (Llama-3.1-8B, larger Qwen2.5 sizes).
pub tie_word_embeddings: bool,
/// MLX affine weight quantization, if the checkpoint is quantized (`None` for
/// an unquantized bf16/f16/f32 checkpoint). The graph itself is unchanged (it
/// runs in f32); the loader dequantizes the packed weights at load.
pub quantization: Option<QuantConfig>,
}

impl Config {
Expand All @@ -69,6 +83,7 @@ impl Config {
},
qkv_bias: false,
tie_word_embeddings: true,
quantization: None,
}
}

Expand Down Expand Up @@ -120,6 +135,24 @@ impl Config {
.and_then(serde_json::Value::as_bool)
.unwrap_or(true);

// MLX affine quantization: an optional `{bits, group_size}` block. The
// loader dequantizes the packed weights; the emitted graph is unchanged.
let quantization = match v.get("quantization") {
None | Some(serde_json::Value::Null) => None,
Some(q) => {
let qu = |k: &str| -> Result<usize, String> {
q.get(k)
.and_then(serde_json::Value::as_u64)
.map(|x| x as usize)
.ok_or_else(|| format!("config.json quantization missing integer `{k}`"))
};
Some(QuantConfig {
bits: qu("bits")?,
group_size: qu("group_size")?,
})
}
};

let u = |k: &str| -> Result<usize, String> {
v.get(k)
.and_then(serde_json::Value::as_u64)
Expand Down Expand Up @@ -193,6 +226,7 @@ impl Config {
rope,
qkv_bias,
tie_word_embeddings,
quantization,
})
}

Expand Down
1 change: 1 addition & 0 deletions src/lib/mlxcel-xla/src/emitter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ mod tests {
rope: RopeScaling::Plain,
qkv_bias,
tie_word_embeddings: true,
quantization: None,
}
}

Expand Down
66 changes: 59 additions & 7 deletions src/lib/mlxcel-xla/src/iree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,11 @@
//! [`Config`], (2) emits the `prefill` / `decode_step` StableHLO graphs from that
//! config (the #451 emitter, ending in an on-device argmax) and compiles them to
//! vmfbs with the dist's `iree-compile`, cached by content hash, (3) loads the
//! weights as f32 (widening bf16 / f16, or copying f32) in the emitter's arg
//! order, from either a single-file `model.safetensors` or a sharded checkpoint
//! (via its `model.safetensors.index.json`, which is how the big untied models
//! ship), and (4) hands all of it to the
//! weights as f32 (widening bf16 / f16, copying f32, or dequantizing MLX 4 / 8-bit
//! affine weights) in the emitter's arg order, from either a single-file
//! `model.safetensors` or a sharded checkpoint (via its
//! `model.safetensors.index.json`, which is how the big untied models ship), and
//! (4) hands all of it to the
//! shim, which keeps the weights resident on the device and threads the KV cache
//! across steps. Then [`IreeLlama::prefill`] / [`IreeLlama::decode`] are token-in
//! / token-out. Emitting from config (issue #449 M3 Stage 2d) replaced the bundled
Expand All @@ -50,7 +51,7 @@ use memmap2::Mmap;
use safetensors::{Dtype, SafeTensors};

use crate::emitter::{Config, emit_decode, emit_decode_ragged, emit_prefill};
use crate::weights::{bf16_to_f32, f16_to_f32, f32_le_to_f32};
use crate::weights::{bf16_to_f32, dequantize_affine, f16_to_f32, f32_le_to_f32};

/// Prefill bucket baked into the emitted `prefill` graph (`tensor<256xi32>`,
/// == MAX_SEQ, so it covers any prompt the 256-slot KV cache holds).
Expand Down Expand Up @@ -385,6 +386,57 @@ fn load_weights(
let t = st
.tensor(name)
.map_err(|e| format!("weight {name} in {}: {e}", shard.display()))?;

// MLX affine-quantized weight: a `U32`-packed `[out, in_packed]` weight
// with same-shard `*.scales` / `*.biases`. Dequantize to `[out, in]` f32
// (the graph's dtype); the layernorms and q/k/v biases are not quantized
// and fall through to the widen path below.
if t.dtype() == Dtype::U32 {
    let qc = cfg.quantization.ok_or_else(|| {
        format!(
            "weight {name} is U32 (quantized) but config.json has no `quantization`"
        )
    })?;
    let prefix = name
        .strip_suffix(".weight")
        .ok_or_else(|| format!("quantized weight {name} does not end in `.weight`"))?;
    let scales_name = format!("{prefix}.scales");
    let biases_name = format!("{prefix}.biases");
    let scales = st.tensor(&scales_name).map_err(|e| {
        format!("{scales_name} (for {name}) in {}: {e}", shard.display())
    })?;
    let biases = st.tensor(&biases_name).map_err(|e| {
        format!("{biases_name} (for {name}) in {}: {e}", shard.display())
    })?;
    if scales.dtype() != Dtype::F16 || biases.dtype() != Dtype::F16 {
        return Err(format!(
            "{prefix} scales/biases dtype {:?}/{:?}, expected F16",
            scales.dtype(),
            biases.dtype()
        ));
    }
    let shape = t.shape();
    if shape.len() != 2 {
        return Err(format!("quantized weight {name} rank {} != 2", shape.len()));
    }
    let (out, in_packed) = (shape[0], shape[1]);
    bufs[i] = dequantize_affine(
        t.data(),
        scales.data(),
        biases.data(),
        out,
        in_packed,
        qc.bits,
        qc.group_size,
    )
    .map_err(|e| format!("dequantize {name}: {e}"))?;
    // Compute the unpacked width only AFTER `dequantize_affine` has accepted
    // `qc.bits`: it rejects anything but 4 / 8, so this division can no longer
    // hit `bits == 0` from a malformed config.json (which previously panicked
    // with a divide-by-zero instead of returning an error).
    let in_ = in_packed * (32 / qc.bits);
    ranks[i] = 2;
    dims[i * 4] = out as i64;
    dims[i * 4 + 1] = in_ as i64;
    continue;
}

// Widen to f32 (the graph's weight dtype). bf16 and f16 are the common
// checkpoint dtypes; f32 is a passthrough. The widening is exact for
// all three, so it matches HF's f32 reference.
Expand All @@ -394,8 +446,8 @@ fn load_weights(
Dtype::F32 => f32_le_to_f32(t.data()),
other => {
return Err(format!(
"weight {name} dtype {other:?}, expected BF16/F16/F32 \
(a quantized checkpoint must be dequantized first)"
"weight {name} dtype {other:?}, expected BF16/F16/F32 or \
MLX-quantized U32"
));
}
};
Expand Down
102 changes: 102 additions & 0 deletions src/lib/mlxcel-xla/src/weights.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,74 @@ pub(crate) fn f32_le_to_f32(bytes: &[u8]) -> Vec<f32> {
.collect()
}

/// Dequantize one MLX affine-quantized weight to row-major `[out, in]` f32.
///
/// `packed` is the row-major `[out, in_packed]` u32 weight as little-endian
/// bytes, where `in_packed = in * bits / 32`. `scales` / `biases` are the
/// row-major `[out, in / group_size]` f16 buffers (raw bytes). Each weight is
/// recovered as
/// `w[o, i] = q[o, i] * scale[o, i / group_size] + bias[o, i / group_size]`,
/// where `q` is the `bits`-wide value unpacked low-order-first from
/// `packed[o, i / (32 / bits)]` (the MLX affine layout). The graph runs in f32,
/// so the packed weights are widened here once at load.
///
/// # Errors
///
/// Returns a message when `bits` is not 4 or 8, when `group_size` is zero or
/// does not divide the unpacked `in` dimension, when the `[out, in_packed]`
/// shape overflows `usize`, or when `packed` / `scales` / `biases` do not have
/// the sizes that shape implies.
pub(crate) fn dequantize_affine(
    packed: &[u8],
    scales: &[u8],
    biases: &[u8],
    out: usize,
    in_packed: usize,
    bits: usize,
    group_size: usize,
) -> Result<Vec<f32>, String> {
    if !(bits == 4 || bits == 8) {
        return Err(format!(
            "unsupported quantization bits {bits} (expected 4 or 8)"
        ));
    }
    let per_u32 = 32 / bits; // values packed per u32

    // Shapes come from the checkpoint header; report absurd ones as an error
    // rather than overflowing (a panic in debug builds, a wrap in release).
    let overflow = || format!("quantized weight shape [{out}, {in_packed}] overflows usize");
    let in_ = in_packed.checked_mul(per_u32).ok_or_else(overflow)?;
    if group_size == 0 || !in_.is_multiple_of(group_size) {
        return Err(format!(
            "quantization group_size {group_size} does not divide in dimension {in_}"
        ));
    }
    let n_groups = in_ / group_size;

    let row_bytes = in_packed.checked_mul(4).ok_or_else(overflow)?;
    let packed_len = out.checked_mul(row_bytes).ok_or_else(overflow)?;
    if packed.len() != packed_len {
        return Err(format!(
            "packed weight is {} bytes, expected {} ([{out}, {in_packed}] u32)",
            packed.len(),
            packed_len
        ));
    }

    // From here on `out * in_packed * 4` equals a real slice length, so the
    // remaining products (`out * n_groups`, `out * in_`) cannot overflow.
    let scales = f16_to_f32(scales);
    let biases = f16_to_f32(biases);
    if scales.len() != out * n_groups || biases.len() != out * n_groups {
        return Err(format!(
            "scales/biases have {}/{} elements, expected {} ([{out}, {n_groups}])",
            scales.len(),
            biases.len(),
            out * n_groups
        ));
    }

    let mut w = vec![0f32; out * in_];
    if in_ == 0 {
        // Zero-width rows: nothing to unpack, and `chunks_exact(0)` would panic.
        return Ok(w);
    }

    let mask: u32 = (1u32 << bits) - 1;
    // Walk the output, packed words, scales and biases one row at a time.
    let rows = w
        .chunks_exact_mut(in_)
        .zip(packed.chunks_exact(row_bytes))
        .zip(scales.chunks_exact(n_groups))
        .zip(biases.chunks_exact(n_groups));
    for (((w_row, p_row), s_row), b_row) in rows {
        for (p, word) in p_row.chunks_exact(4).enumerate() {
            let u = u32::from_le_bytes([word[0], word[1], word[2], word[3]]);
            for j in 0..per_u32 {
                let i = p * per_u32 + j; // column within the row
                let g = i / group_size; // scale / bias group of that column
                let q = ((u >> (bits * j)) & mask) as f32;
                w_row[i] = q * s_row[g] + b_row[g];
            }
        }
    }
    Ok(w)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down Expand Up @@ -115,4 +183,38 @@ mod tests {
let bytes = 1.5f32.to_le_bytes();
assert_eq!(f32_le_to_f32(&bytes), vec![1.5]);
}

/// Hand-built 4-bit case: a single u32 holds the nibbles 1..=8 (lowest nibble
/// first) and the row is split into two groups of 4 columns with scale/bias
/// (2.0, +10) and (0.5, -1), so every `q*scale + bias` is exactly representable.
#[test]
fn dequantize_affine_recovers_hand_example() {
    // 0x8765_4321 stored little-endian -> nibbles [1,2,3,4,5,6,7,8] low-order first.
    let word_bytes = 0x8765_4321u32.to_le_bytes();
    // f16 bit patterns: 0x4000 = 2.0, 0x3800 = 0.5, 0x4900 = 10.0, 0xBC00 = -1.0.
    let scale_bytes = [0x4000u16.to_le_bytes(), 0x3800u16.to_le_bytes()].concat();
    let bias_bytes = [0x4900u16.to_le_bytes(), 0xBC00u16.to_le_bytes()].concat();

    let (out, in_packed, bits, group_size) = (1, 1, 4, 4);
    let got = dequantize_affine(
        &word_bytes,
        &scale_bytes,
        &bias_bytes,
        out,
        in_packed,
        bits,
        group_size,
    )
    .unwrap();

    let want = vec![12.0, 14.0, 16.0, 18.0, 1.5, 2.0, 2.5, 3.0];
    assert_eq!(got, want);
}

/// Hand-built 8-bit case: a single u32 holds the bytes 10/20/30/40 (lowest byte
/// first) and the row is split into two groups of 2 columns with scale/bias
/// (2.0, +10) and (0.5, -1), so every `q*scale + bias` is exactly representable.
/// Covers the `bits = 8` (`per_u32 = 4`) unpacking path.
#[test]
fn dequantize_affine_8bit_recovers_hand_example() {
    // 0x281E_140A stored little-endian -> bytes [10, 20, 30, 40] low-order first.
    let word_bytes = 0x281E_140Au32.to_le_bytes();
    // f16 bit patterns: 0x4000 = 2.0, 0x3800 = 0.5, 0x4900 = 10.0, 0xBC00 = -1.0.
    let scale_bytes = [0x4000u16.to_le_bytes(), 0x3800u16.to_le_bytes()].concat();
    let bias_bytes = [0x4900u16.to_le_bytes(), 0xBC00u16.to_le_bytes()].concat();

    let (out, in_packed, bits, group_size) = (1, 1, 8, 2);
    let got = dequantize_affine(
        &word_bytes,
        &scale_bytes,
        &bias_bytes,
        out,
        in_packed,
        bits,
        group_size,
    )
    .unwrap();

    let want = vec![30.0, 50.0, 14.0, 19.0];
    assert_eq!(got, want);
}

/// A packed buffer whose size disagrees with `[out, in_packed]` is rejected.
///
/// The scales / biases buffers are sized correctly for the requested shape so
/// that the packed-size check is the only thing that can reject the call; with
/// a shared too-short buffer the test would still pass if that check regressed.
#[test]
fn dequantize_affine_rejects_size_mismatch() {
    // A [2, 1] u32 weight needs 8 packed bytes; supply only 4.
    let packed = [0u8; 4];
    // out = 2 rows, 8 columns / group_size 4 = 2 groups -> 4 f16 values = 8 bytes.
    let sb = [0u8; 8];
    let err = dequantize_affine(&packed, &sb, &sb, 2, 1, 4, 4).unwrap_err();
    assert!(err.contains("packed weight"), "unexpected error: {err}");
}
}