feat: serve untied LM heads on the OpenXLA path (separate lm_head weight) (#449 M3 Stage 2d) #482
Merged
…ght) (#449 M3 Stage 2d)

`Config::from_json` hard-rejected `tie_word_embeddings = false`, so the OpenXLA backend could only load checkpoints that tie the LM head to the token embedding (Llama-3.2-1B, Qwen2.5-0.5B). That blocked every untied model, including Llama-3.1-8B and the larger Qwen2.5 sizes, whose head is a separate `lm_head.weight`.

Add an untied path that leaves the graphs emitted for existing tied checkpoints byte-identical:

- `Config` gains `tie_word_embeddings`; `from_json` reads it (absent defaults to tied, the HF `PretrainedConfig` default) instead of rejecting untied.
- The emitter takes a `params['lm_head']` (`[V, H]`) arg right after `final_norm` for an untied config, and routes the final logits projection of all four graph kinds (decode, prefill, batched, ragged) through it via `head_weight` (falling back to `embed` when tied). A tied config emits no such arg, so its graph is unchanged.
- `weight_names` adds `lm_head.weight` in the same slot, so the loaded weight buffers line up with the emitted graph args (positional binding).

Validation:

- Unit: `from_json` accepts untied (and defaults absent to tied); untied adds exactly one `[V, H]` arg after `final_norm` to every graph kind and the tail consumes it; the bundled tied Llama-3.2-1B graphs stay byte-for-byte identical (the existing byte-exact asset test).
- E2E on a constructed untied Qwen2.5-0.5B (CUDA, GB10): with `lm_head = embed` the greedy stream is token-exact with the tied HF oracle; with `lm_head = -embed` it is token-exact with its own (divergent) HF oracle, proving the engine consults the separate head rather than the embedding. The ragged serve path is reference-exact at B_max 4 and 8.

Sharded-safetensors loading (the remaining gap for the big untied models, which ship as multi-file checkpoints) is a separate follow-up; it also blocks tied big models, so it is orthogonal to this change.
anyhow 1.0.102 is flagged by RUSTSEC-2026-0190 (unsoundness in Error::downcast_mut). The advisory landed 2026-06-29 and fails cargo-deny repo-wide, unrelated to the untied LM head change in this PR; bump to the advisory's recommended 1.0.103 (Cargo.lock only; the manifest constraint anyhow = "1.0.100" already allows it) to restore a green CI gate.
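The positional binding described in the untied-head commit can be illustrated with a minimal sketch. Everything here is an assumption made for illustration, not the PR's actual code: the pared-down `Config`, the `num_layers` field, the single tensor listed per layer, and the assumption that the final norm is the last argument before the optional head. Only the tensor names `model.embed_tokens.weight`, `model.norm.weight`, and `lm_head.weight` follow the usual Hugging Face naming for Llama-style checkpoints.

```rust
/// Hypothetical, pared-down stand-in for the backend's `Config`.
pub struct Config {
    pub num_layers: usize,
    pub tie_word_embeddings: bool,
}

/// Checkpoint tensor names in the order the graph binds its arguments.
/// Assumption: one tensor per layer is listed to keep the sketch short.
pub fn weight_names(cfg: &Config) -> Vec<String> {
    let mut names = vec!["model.embed_tokens.weight".to_string()];
    for i in 0..cfg.num_layers {
        names.push(format!("model.layers.{}.input_layernorm.weight", i));
    }
    names.push("model.norm.weight".to_string());
    // The untied head occupies the slot right after the final norm;
    // a tied config adds nothing, so its list is unchanged.
    if !cfg.tie_word_embeddings {
        names.push("lm_head.weight".to_string());
    }
    names
}

/// Name of the matrix the final logits projection reads:
/// the embedding when tied, the separate head otherwise.
pub fn head_weight(cfg: &Config) -> &'static str {
    if cfg.tie_word_embeddings {
        "model.embed_tokens.weight"
    } else {
        "lm_head.weight"
    }
}
```

Because binding is positional, the property worth checking in a sketch like this is that the untied list equals the tied list plus exactly one trailing entry, which mirrors the "exactly one `[V, H]` arg after `final_norm`" unit check.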
Summary
`Config::from_json` hard-rejected `tie_word_embeddings = false`, so the OpenXLA backend could only load checkpoints that tie the LM head to the token embedding (Llama-3.2-1B, Qwen2.5-0.5B). This adds an untied path: an untied checkpoint carries a separate `lm_head.weight` that feeds the final logits projection instead of the shared `embed` matrix. This unblocks Llama-3.1-8B and the larger Qwen2.5 sizes across both already-supported architectures (Llama, Qwen2), as part of #449 M3 Stage 2d.
Tied checkpoints are byte-identical to before (the untied arg is only emitted when `tie_word_embeddings = false`).
What changed
- `Config` gains `tie_word_embeddings`; `from_json` reads it (absent defaults to tied, the HF `PretrainedConfig` default) instead of rejecting untied.
- The emitter takes a `params['lm_head']` (`[V, H]`) arg right after `final_norm` for an untied config, and routes the final logits projection of all four graph kinds (decode, prefill, batched, ragged) through it via `head_weight` (falling back to `embed` when tied).
- `weight_names` adds `lm_head.weight` in the same slot, so the loaded weight buffers line up with the emitted graph args (positional binding).
Validation
Unit (`cargo test -p mlxcel-xla`, 33 pass): `from_json` accepts untied and defaults an absent field to tied; untied adds exactly one `[V, H]` arg after `final_norm` to every graph kind and the tail consumes it; the bundled tied Llama-3.2-1B graphs stay byte-for-byte identical (existing byte-exact asset test).
E2E, CUDA on GB10, on a constructed untied Qwen2.5-0.5B (the tied bf16 model with a separate `lm_head.weight` added, `tie_word_embeddings=false`):
- `untied-copy` (`lm_head` = `embed`): the greedy stream is token-exact with the tied HF oracle.
- `untied-neg` (`lm_head` = `-embed`): the greedy stream is token-exact with its own (divergent) HF oracle.
`untied-copy` proves the untied path is correct; `untied-neg` (negated head flips the argmax, a completely different HF stream) proves the engine genuinely consults the separate `lm_head` weight rather than the embedding.
Follow-up
Sharded-safetensors loading is the remaining gap for the genuinely large untied models (they ship as multi-file checkpoints; the loader reads single-file `model.safetensors`). It also blocks tied big models, so it is orthogonal to this change and tracked separately.
Refs #449.
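The reasoning behind the `untied-neg` check in Validation can be shown with a toy example: negating the head negates every logit, so greedy decoding picks what used to be the minimum, while an engine that ignored the head and kept projecting through `embed` would still pick the original token. The sketch below is purely illustrative. The helper names `project`, `argmax`, and `negate`, and the tiny 3x2 matrix, are assumptions and have nothing to do with the real engine code.

```rust
/// Logits for one hidden state: logits[v] = sum over h of head[v][h] * hidden[h].
fn project(head: &[Vec<f32>], hidden: &[f32]) -> Vec<f32> {
    head.iter()
        .map(|row| row.iter().zip(hidden).map(|(w, x)| w * x).sum::<f32>())
        .collect()
}

/// Index of the largest logit, i.e. the token greedy decoding would emit.
/// Ties resolve to the lowest index.
fn argmax(logits: &[f32]) -> usize {
    let mut best = 0;
    for (i, &v) in logits.iter().enumerate() {
        if v > logits[best] {
            best = i;
        }
    }
    best
}

/// Element-wise negation of a `[V, H]` matrix (the `lm_head = -embed` case).
fn negate(m: &[Vec<f32>]) -> Vec<Vec<f32>> {
    m.iter()
        .map(|row| row.iter().map(|w| -w).collect())
        .collect()
}
```

With a head equal to `embed` the chosen token matches the tied model. With the negated head the choice moves to a different token, which is why a matching divergent stream is evidence that the separate weight is actually being read.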