diff --git a/Cargo.lock b/Cargo.lock index 680c34ac..e399c592 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -98,9 +98,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.102" +version = "1.0.103" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" +checksum = "2a4385e2e34eb35d6b3efe798b9eb88096925d87726c0798709bf56d9ed84af3" [[package]] name = "arrayref" diff --git a/src/lib/mlxcel-xla/src/emitter/config.rs b/src/lib/mlxcel-xla/src/emitter/config.rs index fa63d132..2679bc13 100644 --- a/src/lib/mlxcel-xla/src/emitter/config.rs +++ b/src/lib/mlxcel-xla/src/emitter/config.rs @@ -3,8 +3,10 @@ //! [`Config::from_json`] reads the same shape from a checkpoint's `config.json` //! (issue #449 M3 Stage 2d). Stage A covered the Llama architecture (llama3 RoPE, //! no attention bias); Stage B adds Qwen2 (plain RoPE + QKV bias), so the config -//! carries the two architecture switches the emitter branches on: the RoPE kind -//! and whether q/k/v projections have a bias. +//! carries the architecture switches the emitter branches on: the RoPE kind, +//! whether q/k/v projections have a bias, and whether the LM head is tied to the +//! token embedding (tied) or a separate `lm_head.weight` (untied, e.g. +//! Llama-3.1-8B and the larger Qwen2.5 checkpoints). /// How the RoPE inverse-frequency table is computed. Both kinds share the /// `outer(pos, inv_freq)` table build (see [`rope`](super::rope)); they differ @@ -39,6 +41,11 @@ pub struct Config { /// q/k/v projections carry a bias (Qwen2). `o_proj` never does, and the MLP /// projections never do, so this single switch covers the architecture delta. pub qkv_bias: bool, + /// The LM head shares the token-embedding matrix (HF `tie_word_embeddings`). + /// `true` (Llama-3.2-1B, Qwen2.5-0.5B) reuses `params['embed']` for the final + /// projection; `false` adds a separate `params['lm_head']` weight the tail + /// projects through instead (Llama-3.1-8B, larger Qwen2.5 sizes). + pub tie_word_embeddings: bool, } impl Config { @@ -61,16 +68,18 @@ impl Config { orig_ctx: 8192, }, qkv_bias: false, + tie_word_embeddings: true, } } /// Build a [`Config`] from a model's `config.json` text. /// - /// Scope: the Llama and Qwen2 architectures (RMSNorm, SwiGLU MLP, GQA, tied - /// embeddings). Llama uses llama3 RoPE scaling and no attention bias; Qwen2 - /// uses plain RoPE and a q/k/v projection bias. Configs the emitter cannot yet - /// reproduce are rejected with a clear error rather than silently mis-emitted: - /// an unsupported `model_type`, untied embeddings, a `llama` checkpoint with + /// Scope: the Llama and Qwen2 architectures (RMSNorm, SwiGLU MLP, GQA, tied or + /// untied embeddings). Llama uses llama3 RoPE scaling and no attention bias; + /// Qwen2 uses plain RoPE and a q/k/v projection bias; either may tie its LM + /// head to the token embedding or carry a separate `lm_head.weight`. Configs + /// the emitter cannot yet reproduce are rejected with a clear error rather than + /// silently mis-emitted: an unsupported `model_type`, a `llama` checkpoint with /// `attention_bias`, or a `rope_scaling` whose `rope_type` is not `llama3`. pub fn from_json_str(s: &str) -> Result { let v: serde_json::Value = @@ -103,14 +112,13 @@ impl Config { } }; - if v.get("tie_word_embeddings") + // Tied (share `embed` for the head) vs untied (separate `lm_head.weight`). + // HF `PretrainedConfig` defaults this to `true`, so an absent field means + // tied; the emitter and the weight loader branch on it. + let tie_word_embeddings = v + .get("tie_word_embeddings") .and_then(serde_json::Value::as_bool) - != Some(true) - { - return Err("the OpenXLA emitter assumes tied word embeddings; \ - config.json tie_word_embeddings != true (untied LM head is a follow-up)" - .to_string()); - } + .unwrap_or(true); let u = |k: &str| -> Result { v.get(k) @@ -184,6 +192,7 @@ impl Config { vocab: u("vocab_size")?, rope, qkv_bias, + tie_word_embeddings, }) } diff --git a/src/lib/mlxcel-xla/src/emitter/mod.rs b/src/lib/mlxcel-xla/src/emitter/mod.rs index d5c7f6b6..21ea69b7 100644 --- a/src/lib/mlxcel-xla/src/emitter/mod.rs +++ b/src/lib/mlxcel-xla/src/emitter/mod.rs @@ -19,13 +19,14 @@ //! `config.json` at load, instead of being pinned to the bundled Llama-3.2-1B //! `.mlir` assets. //! -//! Scope: the Llama and Qwen2 architectures (RMSNorm, SwiGLU MLP, GQA, tied -//! embeddings). The `Config` is parameterized by dimensions, so any checkpoint of -//! a supported architecture (any size) emits correctly. The two architecture -//! switches the emitter branches on are the RoPE kind (llama3 scaling for Llama, -//! plain for Qwen2) and whether the q/k/v projections carry a bias (Qwen2, Stage -//! B); other architectures (e.g. Gemma GeGLU / softcap) are a follow-up that -//! extends the emitter and `from_json`. +//! Scope: the Llama and Qwen2 architectures (RMSNorm, SwiGLU MLP, GQA, tied or +//! untied embeddings). The `Config` is parameterized by dimensions, so any +//! checkpoint of a supported architecture (any size) emits correctly. The +//! architecture switches the emitter branches on are the RoPE kind (llama3 scaling +//! for Llama, plain for Qwen2), whether the q/k/v projections carry a bias (Qwen2, +//! Stage B), and whether the LM head is tied to the token embedding or a separate +//! `lm_head.weight` (untied, e.g. Llama-3.1-8B); other architectures (e.g. Gemma +//! GeGLU / softcap) are a follow-up that extends the emitter and `from_json`. //! //! Pure Rust (no IREE), so it compiles and is unit-tested without the `iree` //! feature; only the IREE engine consumes it. The bundled `.mlir` assets remain @@ -81,6 +82,7 @@ mod tests { vocab: 10, rope: RopeScaling::Plain, qkv_bias, + tie_word_embeddings: true, } } @@ -139,9 +141,9 @@ mod tests { assert_eq!(c.eps, 1e-6); } - /// A non-Llama/Qwen2 architecture, untied embeddings, and an unsupported - /// `rope_scaling` are each rejected with a clear message rather than - /// mis-emitted. + /// A non-Llama/Qwen2 architecture and an unsupported `rope_scaling` are each + /// rejected with a clear message rather than mis-emitted. (Untied embeddings + /// are no longer rejected; see `from_json_accepts_untied_embeddings`.) #[test] fn from_json_rejects_unsupported_configs() { let gemma = r#"{"model_type":"gemma2","tie_word_embeddings":true,"hidden_size":8, @@ -153,15 +155,6 @@ mod tests { .contains("model_type") ); - let untied = r#"{"model_type":"qwen2","tie_word_embeddings":false,"hidden_size":8, - "num_attention_heads":2,"intermediate_size":16,"num_hidden_layers":2, - "num_key_value_heads":1,"rms_norm_eps":1e-6,"rope_theta":1e4,"vocab_size":10}"#; - assert!( - Config::from_json_str(untied) - .unwrap_err() - .contains("tie_word_embeddings") - ); - let yarn = r#"{"model_type":"qwen2","tie_word_embeddings":true,"hidden_size":8, "num_attention_heads":2,"intermediate_size":16,"num_hidden_layers":2, "num_key_value_heads":1,"rms_norm_eps":1e-6,"rope_theta":1e4,"vocab_size":10, @@ -173,6 +166,31 @@ mod tests { ); } + /// Untied embeddings are supported (issue #449 M3 Stage 2d): `from_json` reads + /// `tie_word_embeddings = false`, and an absent field defaults to tied (the HF + /// `PretrainedConfig` default). + #[test] + fn from_json_accepts_untied_embeddings() { + let untied = r#"{"model_type":"qwen2","tie_word_embeddings":false,"hidden_size":8, + "num_attention_heads":2,"intermediate_size":16,"num_hidden_layers":2, + "num_key_value_heads":1,"rms_norm_eps":1e-6,"rope_theta":1e4,"vocab_size":10}"#; + let c = Config::from_json_str(untied).expect("untied qwen2 parses"); + assert!( + !c.tie_word_embeddings, + "tie_word_embeddings=false -> untied" + ); + + let absent = r#"{"model_type":"qwen2","hidden_size":8, + "num_attention_heads":2,"intermediate_size":16,"num_hidden_layers":2, + "num_key_value_heads":1,"rms_norm_eps":1e-6,"rope_theta":1e4,"vocab_size":10}"#; + assert!( + Config::from_json_str(absent) + .expect("absent field parses") + .tie_word_embeddings, + "absent tie_word_embeddings defaults to tied" + ); + } + /// Turning on `qkv_bias` adds exactly the three q/k/v projection biases per /// layer to every graph kind, and the adds that consume them: the single-token /// decode adds one `stablehlo.add` per bias; the seq graphs (prefill / batched @@ -271,6 +289,68 @@ mod tests { assert!(!mlir.contains("['bv']")); } + /// Untied embeddings add exactly one weight arg — the `[V, H]` + /// `params['lm_head']` — to every graph kind, positioned right after + /// `final_norm` and before the layers (arg 2), and the final projection + /// consumes it. A `tie_word_embeddings = true` config emits no such arg, so the + /// arg counts differ by exactly one and a tied graph never names `lm_head` (the + /// guard that keeps every shipped tied checkpoint byte-identical). Mirrors + /// `weight_names` in `iree.rs`, which adds `lm_head.weight` in the same slot. + #[test] + fn untied_adds_one_lm_head_arg_after_final_norm() { + let tied = qwen_like(true); + let mut untied = tied.clone(); + untied.tie_word_embeddings = false; + + let graphs = [ + ( + emit_decode(&untied, false), + emit_decode(&tied, false), + "decode", + ), + ( + emit_prefill(&untied, false), + emit_prefill(&tied, false), + "prefill", + ), + ( + super::model::emit_decode_batched(&untied, 4, false), + super::model::emit_decode_batched(&tied, 4, false), + "batched", + ), + ( + emit_decode_ragged(&untied, 4, false), + emit_decode_ragged(&tied, 4, false), + "ragged", + ), + ]; + for (g_untied, g_tied, name) in graphs { + assert_eq!( + arg_count(&g_untied) - arg_count(&g_tied), + 1, + "{name}: untied adds exactly the lm_head arg" + ); + assert_eq!( + occurs(&g_untied, "params['lm_head']"), + 1, + "{name}: lm_head declared exactly once (signature only)" + ); + assert_eq!( + occurs(&g_tied, "params['lm_head']"), + 0, + "{name}: tied graph never names lm_head" + ); + // arg order: final_norm (arg 1) < lm_head (arg 2) < layer 0's weights. + let fnorm = g_untied.find("params['final_norm']").unwrap(); + let lm = g_untied.find("params['lm_head']").unwrap(); + let l0 = g_untied.find("params['layers'][0]").unwrap(); + assert!( + fnorm < lm && lm < l0, + "{name}: expected final_norm < lm_head < layer0 arg order" + ); + } + } + /// Plain RoPE base frequencies are the textbook `1 / theta^(2i/head_dim)` /// (Qwen2), distinct from the llama3-scaled table. #[test] diff --git a/src/lib/mlxcel-xla/src/emitter/model.rs b/src/lib/mlxcel-xla/src/emitter/model.rs index 9fa321e2..8fa690ba 100644 --- a/src/lib/mlxcel-xla/src/emitter/model.rs +++ b/src/lib/mlxcel-xla/src/emitter/model.rs @@ -8,7 +8,11 @@ //! (alphabetical within each layer), each carrying its pytree-path loc so the //! arg-to-weight mapping is self-documenting and reuses the JAX weight glue. For //! a `qkv_bias` architecture (Qwen2) the per-layer q/k/v projection biases follow -//! the layer's weights (see [`take_layer_weights`]). +//! the layer's weights (see [`take_layer_weights`]). For an untied checkpoint +//! (`tie_word_embeddings = false`) a separate `params['lm_head']` weight follows +//! `final_norm` and feeds the final logits projection in place of the shared +//! `embed` matrix (see [`take_lm_head`]); a tied checkpoint emits no such arg and +//! is byte-identical to before. use super::builder::{Builder, Ty, Val}; use super::config::Config; @@ -45,6 +49,8 @@ struct LayerW { struct Args { embed: Val, final_norm: Val, + /// Untied LM head (`None` when tied; the tail then reuses `embed`). + lm_head: Option, layers: Vec, token: Val, pos: Val, @@ -69,6 +75,33 @@ fn take_arg(decls: &mut Vec, idx: &mut usize, ty: Ty, loc: String) -> V val } +/// Take the untied LM head weight `params['lm_head']` (`[V, H]`), or `None` for a +/// tied checkpoint (which reuses `embed` for the final projection). Called right +/// after `final_norm` and before the layers, so the weight arg order is embed, +/// final_norm, [lm_head when untied], layers..., matching `weight_names` in +/// `iree.rs`. For a tied model nothing is emitted, so the graph stays byte- +/// identical (the guard that keeps every tied checkpoint unchanged). +fn take_lm_head(decls: &mut Vec, idx: &mut usize, c: &Config) -> Option { + if c.tie_word_embeddings { + None + } else { + Some(take_arg( + decls, + idx, + Ty::f32(vec![c.vocab, c.hidden]), + "params['lm_head']".into(), + )) + } +} + +/// The weight the final logits projection multiplies by: the dedicated `lm_head` +/// for an untied checkpoint, else the tied token-embedding matrix. Both are +/// `[V, H]` (`linear` computes `x @ W^T`), so the tail is identical apart from +/// which buffer it reads. +fn head_weight<'a>(embed: &'a Val, lm_head: &'a Option) -> &'a Val { + lm_head.as_ref().unwrap_or(embed) +} + /// Append layer `li`'s weights (and, for `qkv_bias`, its q/k/v biases) in the one /// canonical order every graph kind shares, so the emitted arg order matches /// `weight_names` in `iree.rs` exactly. JAX-alphabetical weights (down, gate, @@ -156,6 +189,7 @@ fn build_arg_schema(c: &Config) -> (Vec, Args) { Ty::f32(vec![h]), "params['final_norm']".into(), ); + let lm_head = take_lm_head(&mut decls, &mut idx, c); let mut layers = Vec::with_capacity(c.n_layers); for li in 0..c.n_layers { @@ -183,6 +217,7 @@ fn build_arg_schema(c: &Config) -> (Vec, Args) { Args { embed, final_norm, + lm_head, layers, token, pos, @@ -385,9 +420,10 @@ pub fn emit_decode(c: &Config, sample: bool) -> String { x = b.add(&x, &down); } - // --- tail: final norm + tied LM head, then optional on-device argmax --- + // --- tail: final norm + LM head (tied embed or untied lm_head), then + // optional on-device argmax --- let xf = rms_norm(&mut b, &x, &a.final_norm, &k, h); - let logits = b.linear(&xf, &a.embed); // [V] + let logits = b.linear(&xf, head_weight(&a.embed, &a.lm_head)); // [V] let (out_val, out_ty) = if sample { let tok = b.argmax(&logits); (tok.name, Ty::scalar("i32").render()) @@ -426,6 +462,7 @@ pub fn emit_decode(c: &Config, sample: bool) -> String { struct BatchedArgs { embed: Val, final_norm: Val, + lm_head: Option, layers: Vec, token: Val, // [B] i32 pos: Val, // scalar i32 (shared across the batch) @@ -453,6 +490,7 @@ fn build_batched_arg_schema(c: &Config, bsz: usize) -> (Vec, BatchedArg Ty::f32(vec![h]), "params['final_norm']".into(), ); + let lm_head = take_lm_head(&mut decls, &mut idx, c); let mut layers = Vec::with_capacity(c.n_layers); for li in 0..c.n_layers { @@ -485,6 +523,7 @@ fn build_batched_arg_schema(c: &Config, bsz: usize) -> (Vec, BatchedArg BatchedArgs { embed, final_norm, + lm_head, layers, token, pos, @@ -655,9 +694,10 @@ pub fn emit_decode_batched(c: &Config, bsz: usize, sample: bool) -> String { x = b.add(&x, &down); } - // --- tail: final norm + tied LM head -> [B, V], optional per-row argmax --- + // --- tail: final norm + LM head (tied embed or untied lm_head) -> [B, V], + // optional per-row argmax --- let xf = rms_norm_seq(&mut b, &x, &a.final_norm, &k, bsz, h); // [B, H] - let logits = b.linear_seq(&xf, &a.embed); // [B, V] + let logits = b.linear_seq(&xf, head_weight(&a.embed, &a.lm_head)); // [B, V] let (out_val, out_ty) = if sample { let tok = b.argmax_batched(&logits); (tok.name, Ty::new(vec![bsz], "i32").render()) @@ -696,6 +736,7 @@ pub fn emit_decode_batched(c: &Config, bsz: usize, sample: bool) -> String { struct RaggedArgs { embed: Val, final_norm: Val, + lm_head: Option, layers: Vec, token: Val, // [B] i32 pos: Val, // [B] i32 (per row) @@ -723,6 +764,7 @@ fn build_ragged_arg_schema(c: &Config, bsz: usize) -> (Vec, RaggedArgs) Ty::f32(vec![h]), "params['final_norm']".into(), ); + let lm_head = take_lm_head(&mut decls, &mut idx, c); let mut layers = Vec::with_capacity(c.n_layers); for li in 0..c.n_layers { @@ -765,6 +807,7 @@ fn build_ragged_arg_schema(c: &Config, bsz: usize) -> (Vec, RaggedArgs) RaggedArgs { embed, final_norm, + lm_head, layers, token, pos, @@ -942,7 +985,7 @@ pub fn emit_decode_ragged(c: &Config, bsz: usize, sample: bool) -> String { } let xf = rms_norm_seq(&mut b, &x, &a.final_norm, &k, bsz, h); - let logits = b.linear_seq(&xf, &a.embed); // [B, V] + let logits = b.linear_seq(&xf, head_weight(&a.embed, &a.lm_head)); // [B, V] let (out_val, out_ty) = if sample { let tok = b.argmax_batched(&logits); (tok.name, Ty::new(vec![bsz], "i32").render()) @@ -981,6 +1024,7 @@ pub fn emit_decode_ragged(c: &Config, bsz: usize, sample: bool) -> String { struct PrefillArgs { embed: Val, final_norm: Val, + lm_head: Option, layers: Vec, tokens: Val, positions: Val, @@ -1006,6 +1050,7 @@ fn build_prefill_arg_schema(c: &Config, lp: usize) -> (Vec, PrefillArgs Ty::f32(vec![h]), "params['final_norm']".into(), ); + let lm_head = take_lm_head(&mut decls, &mut idx, c); let mut layers = Vec::with_capacity(c.n_layers); for li in 0..c.n_layers { @@ -1031,6 +1076,7 @@ fn build_prefill_arg_schema(c: &Config, lp: usize) -> (Vec, PrefillArgs PrefillArgs { embed, final_norm, + lm_head, layers, tokens, positions, @@ -1177,13 +1223,14 @@ pub fn emit_prefill(c: &Config, sample: bool) -> String { x = b.add(&x, &down); } - // --- tail: final norm, take the row at real_len-1, tied LM head --- + // --- tail: final norm, take the row at real_len-1, LM head (tied embed or + // untied lm_head) --- let xf = rms_norm_seq(&mut b, &x, &a.final_norm, &k, lp, h); // [Lp, H] let one_i = b.const_i32(1); let last_idx = b.subtract(&a.real_len, &one_i); // real_len - 1 let last_row = b.dynamic_slice(&xf, &[&last_idx, &k.c0], vec![1, h]); // [1, H] let last = b.reshape(&last_row, vec![h]); // [H] - let logits = b.linear(&last, &a.embed); // [V] + let logits = b.linear(&last, head_weight(&a.embed, &a.lm_head)); // [V] let (out_val, out_ty) = if sample { let tok = b.argmax(&logits); (tok.name, Ty::scalar("i32").render()) diff --git a/src/lib/mlxcel-xla/src/iree.rs b/src/lib/mlxcel-xla/src/iree.rs index 2dca7a8c..372f1c11 100644 --- a/src/lib/mlxcel-xla/src/iree.rs +++ b/src/lib/mlxcel-xla/src/iree.rs @@ -27,6 +27,9 @@ //! Llama-3.2-1B `.mlir` assets, so any checkpoint of a supported architecture //! loads: Llama (any size) and Qwen2 (plain RoPE + q/k/v bias; Stage B), the //! latter adding its bias tensors to `weight_names` to match the emitted graph. +//! An untied checkpoint (`tie_word_embeddings = false`, e.g. Llama-3.1-8B and the +//! larger Qwen2.5 sizes) adds its `lm_head.weight` to `weight_names`, matching the +//! separate `params['lm_head']` arg the emitter takes for the final projection. //! //! Proven token-exact against the HF temp-0 reference in //! `spike/openxla/artifacts/results.json` before being vendored from the @@ -112,16 +115,23 @@ unsafe extern "C" { /// load (any `b_max` is emittable; the worker selects from this set). pub(crate) const RAGGED_B_VALUES: &[usize] = &[4, 8]; -/// The weight names in the emitter's exact arg order: embed, final_norm, then per -/// layer down, gate, in_ln, post_ln, up, wk, wo, wq, wv, and — for a `qkv_bias` -/// architecture (Qwen2) — the k/q/v projection biases. The layer count and the -/// presence of biases come from the model config so the order matches the emitted -/// graph's args (`take_layer_weights` in `emitter/model.rs`). +/// The weight names in the emitter's exact arg order: embed, final_norm, then — +/// for an untied checkpoint (`tie_word_embeddings = false`) — `lm_head.weight`, +/// then per layer down, gate, in_ln, post_ln, up, wk, wo, wq, wv, and — for a +/// `qkv_bias` architecture (Qwen2) — the k/q/v projection biases. The layer count, +/// the untied head, and the presence of biases come from the model config so the +/// order matches the emitted graph's args (`take_lm_head` / `take_layer_weights` +/// in `emitter/model.rs`). fn weight_names(cfg: &Config) -> Vec { let mut names = vec![ "model.embed_tokens.weight".to_string(), "model.norm.weight".to_string(), ]; + // Untied LM head: a separate `lm_head.weight` follows `final_norm`, matching + // the `params['lm_head']` arg the emitter takes in the same position. + if !cfg.tie_word_embeddings { + names.push("lm_head.weight".to_string()); + } for i in 0..cfg.n_layers { let p = format!("model.layers.{i}."); for suf in [