feat: load sharded and f16/f32 checkpoints on the OpenXLA path (#449 M3 Stage 2d) #484

Open: inureyes wants to merge 1 commit into main from feat/449-xla-sharded-safetensors


Summary

The OpenXLA weight loader read only a single-file model.safetensors and accepted only BF16, so the big checkpoints could not load: they ship as multi-file shards, and most (e.g. Llama-3.1-8B) store F16, not BF16. With untied LM heads now supported (#482), this is the remaining gap before those models run. Part of #449 M3 Stage 2d.

What changed (mlxcel-xla)

  • Sharded loading (iree.rs): when model.safetensors is absent, resolve_weight_shards reads model.safetensors.index.json's weight_map to find each weight's shard. load_weights groups the needed weights by shard, mmaps each shard exactly once, and places the results back into the emitter's arg order by index, so the positional weight→arg binding is unchanged. A dir with both a single file and an index uses the single file.
  • Dtype widening (new weights module): widens BF16 / F16 to f32 (or copies f32). The widening is exact for all three (f32 represents every BF16/F16 value), so it matches HF's own f32 cast. F16 goes through a 65536-entry lookup table built from half_to_f32, so widening a multi-GB checkpoint is memory-bound rather than spending minutes in per-element powi.
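The shard grouping described above can be sketched roughly as follows. This is a minimal illustration under stated assumptions, not the code in iree.rs: `group_by_shard`, `load_in_arg_order`, and the `read_shard` callback are hypothetical names, the `weight_map` is passed in as an already-parsed `HashMap` instead of being read from `model.safetensors.index.json`, and the callback stands in for mmapping a shard and extracting tensors.

```rust
use std::collections::{BTreeMap, HashMap};

/// Hypothetical helper: group weight names by the shard that stores them,
/// remembering each weight's position in the emitter's argument order.
fn group_by_shard(
    arg_order: &[String],
    weight_map: &HashMap<String, String>,
) -> Result<BTreeMap<String, Vec<(usize, String)>>, String> {
    let mut groups: BTreeMap<String, Vec<(usize, String)>> = BTreeMap::new();
    for (idx, name) in arg_order.iter().enumerate() {
        let shard = weight_map
            .get(name)
            .ok_or_else(|| format!("weight `{name}` missing from weight_map"))?;
        groups.entry(shard.clone()).or_default().push((idx, name.clone()));
    }
    Ok(groups)
}

/// Visit every shard once and write each tensor into its original arg slot.
/// `read_shard` is an assumed stand-in for "mmap this shard and return the
/// requested tensors, in the order they were requested".
fn load_in_arg_order<F>(
    arg_order: &[String],
    weight_map: &HashMap<String, String>,
    mut read_shard: F,
) -> Result<Vec<Vec<f32>>, String>
where
    F: FnMut(&str, &[(usize, String)]) -> Vec<Vec<f32>>,
{
    let groups = group_by_shard(arg_order, weight_map)?;
    let mut slots: Vec<Option<Vec<f32>>> = vec![None; arg_order.len()];
    for (shard, wanted) in &groups {
        // One call per shard, regardless of how many weights it holds.
        let tensors = read_shard(shard.as_str(), wanted.as_slice());
        for ((idx, _name), data) in wanted.iter().zip(tensors) {
            slots[*idx] = Some(data);
        }
    }
    // Every positional slot must be filled, or the arg binding would shift.
    slots
        .into_iter()
        .enumerate()
        .map(|(i, s)| s.ok_or_else(|| format!("arg {i} was never filled")))
        .collect()
}
```

Writing results back by saved index, not in shard-iteration order, is what keeps the positional weight-to-arg binding independent of how the checkpoint happens to be split.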

Validation

Unit (cargo test -p mlxcel-xla, 38 pass): the weights module's conversions are checked against reference f16 values (zero, normal, subnormal, max normal, inf, nan), plus the bf16 and f32 lanes.
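For illustration, a table-driven F16 widening and the reference values listed above might look like the sketch below. It is not the PR's `weights` module: `f16_bits_to_f32`, `build_f16_lut`, and `widen_f16` are hypothetical names, the bit-level conversion is a stand-in for `half_to_f32`, and the buffer is assumed to hold little-endian 16-bit values.

```rust
/// Convert one IEEE 754 binary16 bit pattern to f32. Every binary16 value
/// is representable in f32, so no rounding occurs.
fn f16_bits_to_f32(h: u16) -> f32 {
    let sign = ((h as u32) & 0x8000) << 16;
    let exp = ((h >> 10) & 0x1f) as u32;
    let frac = (h & 0x03ff) as u32;
    let bits = match (exp, frac) {
        // Signed zero.
        (0, 0) => sign,
        // Subnormal: value is frac * 2^-24 (0x3380_0000 is 2^-24 as f32).
        (0, _) => sign | (frac as f32 * f32::from_bits(0x3380_0000)).to_bits(),
        // Infinity (frac == 0) or NaN (frac != 0).
        (0x1f, _) => sign | 0x7f80_0000 | (frac << 13),
        // Normal: rebias the exponent from 15 to 127 (add 112).
        _ => sign | ((exp + 112) << 23) | (frac << 13),
    };
    f32::from_bits(bits)
}

/// Build the 65536-entry table once; indexing by bit pattern then replaces
/// per-element arithmetic with a single load.
fn build_f16_lut() -> Vec<f32> {
    (0..=u16::MAX).map(f16_bits_to_f32).collect()
}

/// Widen a little-endian F16 byte buffer through the table.
/// A trailing odd byte, if any, is ignored by `chunks_exact`.
fn widen_f16(bytes: &[u8], lut: &[f32]) -> Vec<f32> {
    bytes
        .chunks_exact(2)
        .map(|c| lut[u16::from_le_bytes([c[0], c[1]]) as usize])
        .collect()
}
```

The reference bit patterns for the listed cases are 0x0000 (zero), 0x3c00 (1.0, normal), 0x0001 (smallest subnormal, 2^-24), 0x7bff (max normal, 65504), 0x7c00 (inf), and 0x7e00 (a NaN).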

E2E, CUDA on GB10:

| Gate | Model | Result |
| --- | --- | --- |
| Split-equivalence | Qwen2.5-0.5B, 2-shard copy (same weights, no single file) | TOKEN-EXACT vs single-file oracle |
| Genuine | Llama-3.1-8B (F16, 4 shards, untied, llama3 RoPE) | TOKEN-EXACT vs HF fp32 oracle |

The split-equivalence gate isolates the shard-read logic on a fast model; the Llama-3.1-8B run exercises sharding + F16 + the untied head together at 8B scale.

Notes

Refs #449.

…M3 Stage 2d)

The OpenXLA weight loader read only a single-file `model.safetensors` and
accepted only BF16, so the big checkpoints could not load: they ship as
multi-file shards, and most (e.g. Llama-3.1-8B) store F16, not BF16. With
untied LM heads now supported (#482), this is the remaining gap before those
models run.

Two additions to the loader, both validated token-exact:

- Sharded loading: when `model.safetensors` is absent, `resolve_weight_shards`
  reads `model.safetensors.index.json`'s `weight_map` to find each weight's
  shard. `load_weights` groups the needed weights by shard, mmaps each shard
  exactly once, and places the results back into the emitter's arg order by
  index, so the positional weight->arg binding is unchanged.
- Dtype widening: a new `weights` module widens BF16 / F16 to f32 (or copies
  f32). The widening is exact for all three (f32 represents every BF16/F16
  value), so it matches HF's own f32 cast. F16 goes through a 65536-entry
  lookup table built from `half_to_f32`, so widening a multi-GB checkpoint is
  memory-bound rather than spending minutes in per-element `powi`.
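The exactness claim for the BF16 and f32 lanes can be seen in a short sketch. The function names here are hypothetical, not taken from the `weights` module, and the buffers are assumed to be little-endian.

```rust
/// Widen one bfloat16 bit pattern to f32. BF16 is the top 16 bits of an
/// f32, so widening is a shift that fills the low mantissa bits with zeros
/// and cannot round.
fn bf16_bits_to_f32(b: u16) -> f32 {
    f32::from_bits((b as u32) << 16)
}

/// Widen a little-endian BF16 byte buffer to f32 values.
fn widen_bf16(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(2)
        .map(|c| bf16_bits_to_f32(u16::from_le_bytes([c[0], c[1]])))
        .collect()
}

/// The f32 lane: reinterpret little-endian bytes as f32 values unchanged.
fn copy_f32(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
        .collect()
}
```

Because the original 16 bits survive in the upper half of the result, truncating a widened value back to its top 16 bits recovers the input pattern.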

Validation:
- Unit (`cargo test -p mlxcel-xla`, 38 pass): the `weights` module's
  conversions against reference f16 values (zero, normal, subnormal, max
  normal, inf, nan) plus bf16 and f32 lanes.
- E2E, CUDA on GB10:
  - Split-equivalence: a 2-shard copy of Qwen2.5-0.5B (same weights, no
    single file) is token-exact with the single-file oracle, proving the
    shard read logic.
  - Genuine model: Llama-3.1-8B (F16, 4 shards, untied, llama3 RoPE) is
    token-exact with the HF fp32 oracle, exercising sharding + F16 + the
    untied head at 8B scale together.
inureyes added the labels area:architecture, priority:medium, status:done, and type:enhancement on Jun 29, 2026.