Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

37 changes: 23 additions & 14 deletions src/lib/mlxcel-xla/src/emitter/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,10 @@
//! [`Config::from_json`] reads the same shape from a checkpoint's `config.json`
//! (issue #449 M3 Stage 2d). Stage A covered the Llama architecture (llama3 RoPE,
//! no attention bias); Stage B adds Qwen2 (plain RoPE + QKV bias), so the config
//! carries the two architecture switches the emitter branches on: the RoPE kind
//! and whether q/k/v projections have a bias.
//! carries the architecture switches the emitter branches on: the RoPE kind,
//! whether q/k/v projections have a bias, and whether the LM head is tied to the
//! token embedding (tied) or a separate `lm_head.weight` (untied, e.g.
//! Llama-3.1-8B and the larger Qwen2.5 checkpoints).

/// How the RoPE inverse-frequency table is computed. Both kinds share the
/// `outer(pos, inv_freq)` table build (see [`rope`](super::rope)); they differ
Expand Down Expand Up @@ -39,6 +41,11 @@ pub struct Config {
/// q/k/v projections carry a bias (Qwen2). `o_proj` never does, and the MLP
/// projections never do, so this single switch covers the architecture delta.
pub qkv_bias: bool,
/// The LM head shares the token-embedding matrix (HF `tie_word_embeddings`).
/// `true` (Llama-3.2-1B, Qwen2.5-0.5B) reuses `params['embed']` for the final
/// projection; `false` adds a separate `params['lm_head']` weight the tail
/// projects through instead (Llama-3.1-8B, larger Qwen2.5 sizes).
pub tie_word_embeddings: bool,
}

impl Config {
Expand All @@ -61,16 +68,18 @@ impl Config {
orig_ctx: 8192,
},
qkv_bias: false,
tie_word_embeddings: true,
}
}

/// Build a [`Config`] from a model's `config.json` text.
///
/// Scope: the Llama and Qwen2 architectures (RMSNorm, SwiGLU MLP, GQA, tied
/// embeddings). Llama uses llama3 RoPE scaling and no attention bias; Qwen2
/// uses plain RoPE and a q/k/v projection bias. Configs the emitter cannot yet
/// reproduce are rejected with a clear error rather than silently mis-emitted:
/// an unsupported `model_type`, untied embeddings, a `llama` checkpoint with
/// Scope: the Llama and Qwen2 architectures (RMSNorm, SwiGLU MLP, GQA, tied or
/// untied embeddings). Llama uses llama3 RoPE scaling and no attention bias;
/// Qwen2 uses plain RoPE and a q/k/v projection bias; either may tie its LM
/// head to the token embedding or carry a separate `lm_head.weight`. Configs
/// the emitter cannot yet reproduce are rejected with a clear error rather than
/// silently mis-emitted: an unsupported `model_type`, a `llama` checkpoint with
/// `attention_bias`, or a `rope_scaling` whose `rope_type` is not `llama3`.
pub fn from_json_str(s: &str) -> Result<Self, String> {
let v: serde_json::Value =
Expand Down Expand Up @@ -103,14 +112,13 @@ impl Config {
}
};

if v.get("tie_word_embeddings")
// Tied (share `embed` for the head) vs untied (separate `lm_head.weight`).
// HF `PretrainedConfig` defaults this to `true`, so an absent field means
// tied; the emitter and the weight loader branch on it.
let tie_word_embeddings = v
.get("tie_word_embeddings")
.and_then(serde_json::Value::as_bool)
!= Some(true)
{
return Err("the OpenXLA emitter assumes tied word embeddings; \
config.json tie_word_embeddings != true (untied LM head is a follow-up)"
.to_string());
}
.unwrap_or(true);

let u = |k: &str| -> Result<usize, String> {
v.get(k)
Expand Down Expand Up @@ -184,6 +192,7 @@ impl Config {
vocab: u("vocab_size")?,
rope,
qkv_bias,
tie_word_embeddings,
})
}

Expand Down
118 changes: 99 additions & 19 deletions src/lib/mlxcel-xla/src/emitter/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
//! `config.json` at load, instead of being pinned to the bundled Llama-3.2-1B
//! `.mlir` assets.
//!
//! Scope: the Llama and Qwen2 architectures (RMSNorm, SwiGLU MLP, GQA, tied
//! embeddings). The `Config` is parameterized by dimensions, so any checkpoint of
//! a supported architecture (any size) emits correctly. The two architecture
//! switches the emitter branches on are the RoPE kind (llama3 scaling for Llama,
//! plain for Qwen2) and whether the q/k/v projections carry a bias (Qwen2, Stage
//! B); other architectures (e.g. Gemma GeGLU / softcap) are a follow-up that
//! extends the emitter and `from_json`.
//! Scope: the Llama and Qwen2 architectures (RMSNorm, SwiGLU MLP, GQA, tied or
//! untied embeddings). The `Config` is parameterized by dimensions, so any
//! checkpoint of a supported architecture (any size) emits correctly. The
//! architecture switches the emitter branches on are the RoPE kind (llama3 scaling
//! for Llama, plain for Qwen2), whether the q/k/v projections carry a bias (Qwen2,
//! Stage B), and whether the LM head is tied to the token embedding or a separate
//! `lm_head.weight` (untied, e.g. Llama-3.1-8B); other architectures (e.g. Gemma
//! GeGLU / softcap) are a follow-up that extends the emitter and `from_json`.
//!
//! Pure Rust (no IREE), so it compiles and is unit-tested without the `iree`
//! feature; only the IREE engine consumes it. The bundled `.mlir` assets remain
Expand Down Expand Up @@ -81,6 +82,7 @@ mod tests {
vocab: 10,
rope: RopeScaling::Plain,
qkv_bias,
tie_word_embeddings: true,
}
}

Expand Down Expand Up @@ -139,9 +141,9 @@ mod tests {
assert_eq!(c.eps, 1e-6);
}

/// A non-Llama/Qwen2 architecture, untied embeddings, and an unsupported
/// `rope_scaling` are each rejected with a clear message rather than
/// mis-emitted.
/// A non-Llama/Qwen2 architecture and an unsupported `rope_scaling` are each
/// rejected with a clear message rather than mis-emitted. (Untied embeddings
/// are no longer rejected; see `from_json_accepts_untied_embeddings`.)
#[test]
fn from_json_rejects_unsupported_configs() {
let gemma = r#"{"model_type":"gemma2","tie_word_embeddings":true,"hidden_size":8,
Expand All @@ -153,15 +155,6 @@ mod tests {
.contains("model_type")
);

let untied = r#"{"model_type":"qwen2","tie_word_embeddings":false,"hidden_size":8,
"num_attention_heads":2,"intermediate_size":16,"num_hidden_layers":2,
"num_key_value_heads":1,"rms_norm_eps":1e-6,"rope_theta":1e4,"vocab_size":10}"#;
assert!(
Config::from_json_str(untied)
.unwrap_err()
.contains("tie_word_embeddings")
);

let yarn = r#"{"model_type":"qwen2","tie_word_embeddings":true,"hidden_size":8,
"num_attention_heads":2,"intermediate_size":16,"num_hidden_layers":2,
"num_key_value_heads":1,"rms_norm_eps":1e-6,"rope_theta":1e4,"vocab_size":10,
Expand All @@ -173,6 +166,31 @@ mod tests {
);
}

/// `from_json` accepts untied embeddings (issue #449 M3 Stage 2d).
///
/// Two cases are pinned: an explicit `tie_word_embeddings = false` parses as
/// untied, and a config that omits the field altogether falls back to tied (the
/// HF `PretrainedConfig` default).
#[test]
fn from_json_accepts_untied_embeddings() {
    // Case 1: the field is present and explicitly `false`.
    let explicit_false = r#"{"model_type":"qwen2","tie_word_embeddings":false,"hidden_size":8,
"num_attention_heads":2,"intermediate_size":16,"num_hidden_layers":2,
"num_key_value_heads":1,"rms_norm_eps":1e-6,"rope_theta":1e4,"vocab_size":10}"#;
    let parsed = Config::from_json_str(explicit_false).expect("untied qwen2 parses");
    assert!(
        !parsed.tie_word_embeddings,
        "tie_word_embeddings=false -> untied"
    );

    // Case 2: the field is missing, so the parser must pick the tied default.
    let omitted = r#"{"model_type":"qwen2","hidden_size":8,
"num_attention_heads":2,"intermediate_size":16,"num_hidden_layers":2,
"num_key_value_heads":1,"rms_norm_eps":1e-6,"rope_theta":1e4,"vocab_size":10}"#;
    let parsed = Config::from_json_str(omitted).expect("absent field parses");
    assert!(
        parsed.tie_word_embeddings,
        "absent tie_word_embeddings defaults to tied"
    );
}

/// Turning on `qkv_bias` adds exactly the three q/k/v projection biases per
/// layer to every graph kind, and the adds that consume them: the single-token
/// decode adds one `stablehlo.add` per bias; the seq graphs (prefill / batched
Expand Down Expand Up @@ -271,6 +289,68 @@ mod tests {
assert!(!mlir.contains("['bv']"));
}

/// Untied embeddings add exactly one weight arg — the `[V, H]`
/// `params['lm_head']` — to every graph kind, positioned right after
/// `final_norm` and before the layers (arg 2), and the final projection
/// consumes it. A `tie_word_embeddings = true` config emits no such arg, so the
/// arg counts differ by exactly one and a tied graph never names `lm_head` (the
/// guard that keeps every shipped tied checkpoint byte-identical). Mirrors
/// `weight_names` in `iree.rs`, which adds `lm_head.weight` in the same slot.
///
/// Every check is labelled with the graph kind (`decode`, `prefill`,
/// `batched`, `ragged`) so a failure points at the emitter that regressed.
#[test]
fn untied_adds_one_lm_head_arg_after_final_norm() {
    let tied = qwen_like(true);
    let mut untied = tied.clone();
    untied.tie_word_embeddings = false;

    // (untied graph, tied graph, label) for each emitter entry point.
    let graphs = [
        (
            emit_decode(&untied, false),
            emit_decode(&tied, false),
            "decode",
        ),
        (
            emit_prefill(&untied, false),
            emit_prefill(&tied, false),
            "prefill",
        ),
        (
            super::model::emit_decode_batched(&untied, 4, false),
            super::model::emit_decode_batched(&tied, 4, false),
            "batched",
        ),
        (
            emit_decode_ragged(&untied, 4, false),
            emit_decode_ragged(&tied, 4, false),
            "ragged",
        ),
    ];
    for (g_untied, g_tied, name) in graphs {
        // Compare as `untied == tied + 1` rather than `untied - tied == 1`.
        // NOTE(review): `arg_count` is defined outside this view; if it returns
        // an unsigned count, the subtraction form would panic on underflow
        // (debug) or wrap (release) when the untied graph has fewer args,
        // hiding the labelled message below.
        assert_eq!(
            arg_count(&g_untied),
            arg_count(&g_tied) + 1,
            "{name}: untied adds exactly the lm_head arg"
        );
        assert_eq!(
            occurs(&g_untied, "params['lm_head']"),
            1,
            "{name}: lm_head declared exactly once (signature only)"
        );
        assert_eq!(
            occurs(&g_tied, "params['lm_head']"),
            0,
            "{name}: tied graph never names lm_head"
        );
        // arg order: final_norm (arg 1) < lm_head (arg 2) < layer 0's weights.
        // A missing param reports which graph and which name, not a bare unwrap.
        let position = |needle: &str| -> usize {
            g_untied
                .find(needle)
                .unwrap_or_else(|| panic!("{name}: untied graph is missing {needle}"))
        };
        let fnorm = position("params['final_norm']");
        let lm = position("params['lm_head']");
        let l0 = position("params['layers'][0]");
        assert!(
            fnorm < lm && lm < l0,
            "{name}: expected final_norm < lm_head < layer0 arg order"
        );
    }
}

/// Plain RoPE base frequencies are the textbook `1 / theta^(2i/head_dim)`
/// (Qwen2), distinct from the llama3-scaled table.
#[test]
Expand Down
Loading