
feat: serve untied LM heads on the OpenXLA path (separate lm_head weight) (#449 M3 Stage 2d) #482

Merged: inureyes merged 2 commits into main from feat/449-xla-untied-lm-head on Jun 29, 2026

@inureyes (Member)

Summary

Config::from_json hard-rejected tie_word_embeddings = false, so the OpenXLA backend could only load checkpoints that tie the LM head to the token embedding (Llama-3.2-1B, Qwen2.5-0.5B). This adds an untied path: an untied checkpoint carries a separate lm_head.weight that feeds the final logits projection instead of the shared embed matrix. Unblocks Llama-3.1-8B and the larger Qwen2.5 sizes across both already-supported architectures (Llama, Qwen2), part of #449 M3 Stage 2d.

Graphs emitted for tied checkpoints are byte-identical to before (the extra lm_head arg is only emitted when tie_word_embeddings = false).

What changed

  • Config gains tie_word_embeddings; from_json reads it (absent defaults to tied, the HF PretrainedConfig default) instead of rejecting untied.
  • Emitter takes a params['lm_head'] ([V, H]) arg right after final_norm for an untied config, and routes the final logits projection of all four graph kinds (decode, prefill, batched, ragged) through it via head_weight (falling back to embed when tied).
  • weight_names adds lm_head.weight in the same slot, so the loaded weight buffers line up with the emitted graph args (positional binding).
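
The positional binding described in the list above can be pictured with a minimal sketch. Only `tie_word_embeddings`, `weight_names`, `head_weight`, `final_norm`, `embed`, and `lm_head.weight` come from this PR; the trimmed-down `Config`, the per-layer tensor names, and the exact argument order are assumptions for illustration, not the crate's actual code.

```rust
/// Hypothetical, trimmed-down config; the real `Config` has more fields.
struct Config {
    num_layers: usize,
    tie_word_embeddings: bool,
}

/// Checkpoint tensor names in the order the graph arguments are bound.
/// Assumed layout: embed, per-layer weights, final norm, then (untied only) the head.
fn weight_names(cfg: &Config) -> Vec<String> {
    let mut names = vec!["model.embed_tokens.weight".to_string()];
    for i in 0..cfg.num_layers {
        // One representative tensor per layer keeps the sketch short.
        names.push(format!("model.layers.{i}.input_layernorm.weight"));
    }
    names.push("model.norm.weight".to_string()); // final_norm
    if !cfg.tie_word_embeddings {
        // Same slot as the emitter's extra arg: right after final_norm.
        names.push("lm_head.weight".to_string());
    }
    names
}

/// Name of the graph parameter that feeds the final logits projection.
fn head_weight(cfg: &Config) -> &'static str {
    if cfg.tie_word_embeddings {
        "embed"
    } else {
        "lm_head"
    }
}

fn main() {
    let tied = Config { num_layers: 2, tie_word_embeddings: true };
    let untied = Config { num_layers: 2, tie_word_embeddings: false };
    println!("{:?} -> {}", weight_names(&tied), head_weight(&tied));
    println!("{:?} -> {}", weight_names(&untied), head_weight(&untied));
}
```

Because the untied list is the tied list plus one trailing entry, a tied config sees no change in either the weight order or the emitted arguments.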

Validation

Unit (cargo test -p mlxcel-xla, 33 pass): from_json accepts untied and defaults an absent field to tied; untied adds exactly one [V, H] arg after final_norm to every graph kind and the tail consumes it; the bundled tied Llama-3.2-1B graphs stay byte-for-byte identical (existing byte-exact asset test).

E2E, CUDA on GB10, on a constructed untied Qwen2.5-0.5B (the tied bf16 model with a separate lm_head.weight added, tie_word_embeddings=false):

| Model | lm_head | Single-seq vs HF oracle | Serve (ragged) |
| --- | --- | --- | --- |
| tied original | (shares embed) | TOKEN-EXACT (regression) | n/a |
| untied-copy | = embed | TOKEN-EXACT (== tied oracle) | REFERENCE-EXACT B_max 4 & 8 |
| untied-neg | = -embed | TOKEN-EXACT (== its own divergent oracle) | n/a |

untied-copy proves the untied path is correct; untied-neg (negated head flips the argmax, a completely different HF stream) proves the engine genuinely consults the separate lm_head weight rather than the embedding.
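
A toy illustration of why the negated head is a useful probe: negating every head weight negates every logit, so the greedy pick moves from the largest original logit to the smallest. The 3x2 matrix and hidden state below are made-up values, and `logits` and `argmax` are hypothetical helpers, not functions from the crate.

```rust
/// Logits = W · h for a head weight W of shape [V, H] and a hidden state h of length H.
fn logits(w: &[Vec<f32>], h: &[f32]) -> Vec<f32> {
    w.iter()
        .map(|row| row.iter().zip(h).map(|(a, b)| a * b).sum::<f32>())
        .collect()
}

/// Greedy decoding picks the index of the largest logit (first one on ties).
fn argmax(xs: &[f32]) -> usize {
    let mut best = 0;
    for (i, &x) in xs.iter().enumerate() {
        if x > xs[best] {
            best = i;
        }
    }
    best
}

fn main() {
    // Toy "embedding" with V = 3 tokens and H = 2 hidden dims.
    let embed: Vec<Vec<f32>> = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![-1.0, -1.0]];
    // The untied-neg head: every weight negated.
    let neg: Vec<Vec<f32>> = embed
        .iter()
        .map(|row| row.iter().map(|x| -x).collect())
        .collect();
    let h = [2.0, 1.0];

    println!("{}", argmax(&logits(&embed, &h))); // logits [2, 1, -3], prints 0
    println!("{}", argmax(&logits(&neg, &h))); // logits [-2, -1, 3], prints 2
}
```

An engine that silently kept projecting through the embedding would still pick token 0 in the second case, which is the failure the untied-neg run is designed to expose.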

Follow-up

Sharded-safetensors loading is the remaining gap for the genuinely large untied models (they ship as multi-file checkpoints; the loader reads single-file model.safetensors). It also blocks tied big models, so it is orthogonal to this change and tracked separately.

Refs #449.

…ght) (#449 M3 Stage 2d)

`Config::from_json` hard-rejected `tie_word_embeddings = false`, so the
OpenXLA backend could only load checkpoints that tie the LM head to the
token embedding (Llama-3.2-1B, Qwen2.5-0.5B). That blocked every untied
model, including Llama-3.1-8B and the larger Qwen2.5 sizes, whose head is
a separate `lm_head.weight`.

Add an untied path that leaves the graphs of existing tied checkpoints
byte-identical:

- `Config` gains `tie_word_embeddings`; `from_json` reads it (absent
  defaults to tied, the HF `PretrainedConfig` default) instead of
  rejecting untied.
- The emitter takes a `params['lm_head']` (`[V, H]`) arg right after
  `final_norm` for an untied config, and routes the final logits
  projection of all four graph kinds (decode, prefill, batched, ragged)
  through it via `head_weight` (falling back to `embed` when tied). A
  tied config emits no such arg, so its graph is unchanged.
- `weight_names` adds `lm_head.weight` in the same slot, so the loaded
  weight buffers line up with the emitted graph args (positional binding).
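
The "absent defaults to tied" rule from the first bullet can be sketched as follows. `RawConfig` and `from_raw` are hypothetical stand-ins that sidestep JSON parsing; the real entry point is `Config::from_json`, whose internals are not shown in this PR.

```rust
/// Hypothetical view of a parsed config.json: `None` means the key was absent.
struct RawConfig {
    tie_word_embeddings: Option<bool>,
}

/// Trimmed-down config; the real `Config` has more fields.
struct Config {
    tie_word_embeddings: bool,
}

impl Config {
    /// Stand-in for `from_json`: an absent field falls back to tied,
    /// and an explicit `false` is accepted instead of rejected.
    fn from_raw(raw: &RawConfig) -> Config {
        Config {
            tie_word_embeddings: raw.tie_word_embeddings.unwrap_or(true),
        }
    }
}

fn main() {
    let absent = Config::from_raw(&RawConfig { tie_word_embeddings: None });
    let untied = Config::from_raw(&RawConfig { tie_word_embeddings: Some(false) });
    println!("absent -> tied = {}", absent.tie_word_embeddings);
    println!("explicit false -> tied = {}", untied.tie_word_embeddings);
}
```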

Validation:
- Unit: `from_json` accepts untied (and defaults absent to tied); untied
  adds exactly one `[V, H]` arg after `final_norm` to every graph kind and
  the tail consumes it; the bundled tied Llama-3.2-1B graphs stay
  byte-for-byte identical (the existing byte-exact asset test).
- E2E on a constructed untied Qwen2.5-0.5B (CUDA, GB10): with
  `lm_head = embed` the greedy stream is token-exact with the tied HF
  oracle; with `lm_head = -embed` it is token-exact with its own
  (divergent) HF oracle, proving the engine consults the separate head
  rather than the embedding. The ragged serve path is reference-exact at
  B_max 4 and 8.

Sharded-safetensors loading (the remaining gap for the big untied models,
which ship as multi-file checkpoints) is a separate follow-up; it also
blocks tied big models, so it is orthogonal to this change.
@inureyes added the labels area:architecture, priority:medium, status:done, and type:enhancement on Jun 29, 2026
anyhow 1.0.102 is flagged by RUSTSEC-2026-0190 (unsoundness in Error::downcast_mut). The advisory landed 2026-06-29 and fails cargo-deny repo-wide, unrelated to the untied LM head change in this PR; bump to the advisory's recommended 1.0.103 (Cargo.lock only; the manifest constraint anyhow = "1.0.100" already allows it) to restore a green CI gate.
@inureyes inureyes merged commit 8b3ded4 into main Jun 29, 2026
5 checks passed
@inureyes inureyes deleted the feat/449-xla-untied-lm-head branch June 29, 2026 16:49