feat: load sharded and f16/f32 checkpoints on the OpenXLA path (#449 M3 Stage 2d) #484
Open
inureyes wants to merge 1 commit into
…M3 Stage 2d)

The OpenXLA weight loader read only a single-file `model.safetensors` and accepted only BF16, so the big checkpoints could not load: they ship as multi-file shards, and most (e.g. Llama-3.1-8B) store F16, not BF16. With untied LM heads now supported (#482), this is the remaining gap before those models run.

Two additions to the loader, both validated token-exact:

- Sharded loading: when `model.safetensors` is absent, `resolve_weight_shards` reads `model.safetensors.index.json`'s `weight_map` to find each weight's shard. `load_weights` groups the needed weights by shard, mmaps each shard exactly once, and places the results back into the emitter's arg order by index, so the positional weight->arg binding is unchanged.
- Dtype widening: a new `weights` module widens BF16 / F16 to f32 (or copies f32). The widening is exact for all three (f32 represents every BF16/F16 value), so it matches HF's own f32 cast. F16 goes through a 65536-entry lookup table built from `half_to_f32`, so widening a multi-GB checkpoint is memory-bound rather than spending minutes in per-element `powi`.

Validation:

- Unit (`cargo test -p mlxcel-xla`, 38 pass): the `weights` module's conversions against reference f16 values (zero, normal, subnormal, max normal, inf, nan) plus bf16 and f32 lanes.
- E2E, CUDA on GB10:
  - Split-equivalence: a 2-shard copy of Qwen2.5-0.5B (same weights, no single file) is token-exact with the single-file oracle, proving the shard read logic.
  - Genuine model: Llama-3.1-8B (F16, 4 shards, untied, llama3 RoPE) is token-exact with the HF fp32 oracle, exercising sharding + F16 + the untied head at 8B scale together.
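The sharded-loading step in the commit message (group by shard, open each shard once, write results back by arg index) can be sketched roughly as follows. This is a minimal illustration, not the PR's code: `group_by_shard`, `load_in_arg_order`, and the `read_tensor` callback are hypothetical names, and `weight_map` is assumed to be already parsed from `model.safetensors.index.json` into a `HashMap`.

```rust
use std::collections::{BTreeMap, HashMap};

/// Group the needed weights by the shard file that holds them.
/// `needed` is in the emitter's argument order; each entry keeps its
/// index so the result can be written back positionally.
/// (Hypothetical helper, not the PR's `resolve_weight_shards`.)
fn group_by_shard(
    needed: &[&str],
    weight_map: &HashMap<String, String>,
) -> Result<BTreeMap<String, Vec<(usize, String)>>, String> {
    let mut groups: BTreeMap<String, Vec<(usize, String)>> = BTreeMap::new();
    for (arg_index, name) in needed.iter().enumerate() {
        let shard = weight_map
            .get(*name)
            .ok_or_else(|| format!("weight {} missing from weight_map", name))?;
        groups
            .entry(shard.clone())
            .or_default()
            .push((arg_index, (*name).to_string()));
    }
    Ok(groups)
}

/// Visit each shard once and place every tensor at its original arg index.
/// `read_tensor(shard, name)` stands in for the real mmap + decode step.
fn load_in_arg_order<T, F>(
    needed: &[&str],
    weight_map: &HashMap<String, String>,
    mut read_tensor: F,
) -> Result<Vec<T>, String>
where
    F: FnMut(&str, &str) -> T,
{
    let groups = group_by_shard(needed, weight_map)?;
    let mut slots: Vec<Option<T>> = (0..needed.len()).map(|_| None).collect();
    for (shard, entries) in &groups {
        // A real loader would mmap `shard` here, once, before the inner loop.
        for (arg_index, name) in entries {
            slots[*arg_index] = Some(read_tensor(shard.as_str(), name.as_str()));
        }
    }
    // Each index 0..needed.len() was assigned exactly once above.
    slots
        .into_iter()
        .map(|s| s.ok_or_else(|| "unfilled weight slot".to_string()))
        .collect()
}
```

Because every entry carries its original index, the order in which shards are visited does not affect the order of the returned weights, which is what keeps the positional binding stable.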
Summary

The OpenXLA weight loader read only a single-file `model.safetensors` and accepted only BF16, so the big checkpoints could not load: they ship as multi-file shards, and most (e.g. Llama-3.1-8B) store F16, not BF16. With untied LM heads now supported (#482), this is the remaining gap before those models run. Part of #449 M3 Stage 2d.

What changed (`mlxcel-xla`)

- Sharded loading (`iree.rs`): when `model.safetensors` is absent, `resolve_weight_shards` reads `model.safetensors.index.json`'s `weight_map` to find each weight's shard. `load_weights` groups the needed weights by shard, mmaps each shard exactly once, and places the results back into the emitter's arg order by index, so the positional weight→arg binding is unchanged. A directory with both a single file and an index uses the single file.
- Dtype widening (`weights` module): widens BF16 / F16 to f32 (or copies f32). Exact for all three (f32 represents every BF16/F16 value), so it matches HF's own f32 cast. F16 goes through a 65536-entry lookup table built from `half_to_f32`, so widening a multi-GB checkpoint is memory-bound rather than spending minutes in per-element `powi`.

Validation
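Unit (`cargo test -p mlxcel-xla`, 38 pass): the `weights` conversions against reference f16 values (zero, normal, subnormal, max normal, inf, nan) plus bf16 and f32 lanes.

E2E, CUDA on GB10:

- Split-equivalence: a 2-shard copy of Qwen2.5-0.5B (same weights, no single file) is token-exact with the single-file oracle.
- Genuine model: Llama-3.1-8B (F16, 4 shards, untied, llama3 RoPE) is token-exact with the HF fp32 oracle.

The split-equivalence isolates the shard-read logic on a fast model; the Llama-3.1-8B run exercises sharding + F16 + the untied head at 8B scale together.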
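For readers who want to see the shape of the dtype widening, here is a hedged sketch of a table-driven F16 to f32 conversion plus the BF16 shift. It is not the PR's `weights` module: `f16_bits_to_f32`, `build_f16_table`, `widen_f16`, and `widen_bf16` are hypothetical names, the per-element conversion is a bit-manipulation stand-in for the PR's `half_to_f32`, and the input bytes are assumed to be little-endian as in the safetensors format.

```rust
/// Convert one IEEE 754 binary16 bit pattern to f32 exactly.
/// (Illustrative stand-in; the PR builds its table from `half_to_f32`.)
fn f16_bits_to_f32(bits: u16) -> f32 {
    let sign = ((bits as u32) & 0x8000) << 16;
    let exp = ((bits >> 10) & 0x1f) as u32;
    let frac = (bits & 0x03ff) as u32;
    let out = match exp {
        // Zero and subnormals: value = frac * 2^-24, exact in f32.
        0 => sign | (frac as f32 * f32::from_bits(0x3380_0000)).to_bits(),
        // Inf (frac == 0) or NaN (frac != 0, payload shifted into place).
        0x1f => sign | 0x7f80_0000 | (frac << 13),
        // Normal: rebias the exponent from 15 to 127 (add 112).
        _ => sign | ((exp + 112) << 23) | (frac << 13),
    };
    f32::from_bits(out)
}

/// Build the 65536-entry table once; every later lookup is one load.
fn build_f16_table() -> Vec<f32> {
    (0..=u16::MAX).map(f16_bits_to_f32).collect()
}

/// Widen little-endian F16 bytes to f32 via the table.
fn widen_f16(bytes: &[u8], table: &[f32]) -> Vec<f32> {
    bytes
        .chunks_exact(2)
        .map(|b| table[u16::from_le_bytes([b[0], b[1]]) as usize])
        .collect()
}

/// Widen little-endian BF16 bytes to f32: BF16 is the top 16 bits of an f32.
fn widen_bf16(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(2)
        .map(|b| f32::from_bits((u16::from_le_bytes([b[0], b[1]]) as u32) << 16))
        .collect()
}
```

The table costs 256 KiB (65536 entries of 4 bytes), so it is small next to a multi-GB checkpoint, and the reference cases named in the unit tests (zero, normal, subnormal, max normal, inf, nan) each map to a single table index that can be checked directly.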
Notes
`main` (independent of the untied diff; this PR only touches the loader).

Refs #449.