From 281a035e1a93f22623064479d0d8576ca660d917 Mon Sep 17 00:00:00 2001 From: Jordan Anderson Date: Fri, 5 Jun 2026 15:26:10 -0500 Subject: [PATCH 1/3] Add Experts4bit for 4-bit quantization of fused MoE experts (#1849) transformers v5 stores fused MoE experts as a single 3D nn.Parameter (e.g. OlmoeExperts, Qwen3MoeExperts), which the nn.Linear-based 4-bit walker skips. The experts stay in full precision and load_in_4bit barely shrinks the model (issue #1849). Experts4bit holds gate_up_proj and down_proj packed in NF4/FP4 as plain nn.Parameter buffers, with per-expert absmax kept on the module itself. The forward pass dequantizes one expert at a time (a per-expert loop), mirroring the reference fused-experts forward. There is no Params4bit tensor-subclass machinery, so the module serializes through the default state_dict with no custom hooks. - from_float() quantizes existing bf16/fp16 expert stacks - enforces in_features % blocksize == 0 for clean per-expert blocking - double-quant (compress_statistics) and grouped-GEMM intentionally deferred for a first cut - tests: quant round-trip, forward vs. full-precision reference, state_dict round-trip, and validation guards --- bitsandbytes/nn/__init__.py | 1 + bitsandbytes/nn/experts.py | 258 ++++++++++++++++++++++++++++++++++++ tests/test_experts4bit.py | 156 ++++++++++++++++++++++ 3 files changed, 415 insertions(+) create mode 100644 bitsandbytes/nn/experts.py create mode 100644 tests/test_experts4bit.py diff --git a/bitsandbytes/nn/__init__.py b/bitsandbytes/nn/__init__.py index 54c2614bd..225146494 100644 --- a/bitsandbytes/nn/__init__.py +++ b/bitsandbytes/nn/__init__.py @@ -2,6 +2,7 @@ # # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. 
+from .experts import Experts4bit from .modules import ( Embedding, Embedding4bit, diff --git a/bitsandbytes/nn/experts.py b/bitsandbytes/nn/experts.py new file mode 100644 index 000000000..bd2f355dc --- /dev/null +++ b/bitsandbytes/nn/experts.py @@ -0,0 +1,258 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from collections.abc import Callable +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F_nn + +import bitsandbytes.functional as F +from bitsandbytes.functional import QuantState + + +class Experts4bit(nn.Module): + """4-bit quantized storage for fused Mixture-of-Experts expert weights. + + A growing number of models in the Hugging Face ecosystem store their MoE expert + weights as a single 3D ``nn.Parameter`` of shape ``[num_experts, out_features, + in_features]`` (e.g. ``OlmoeExperts``, ``Qwen3MoeExperts``) rather than as a + collection of ``nn.Linear`` layers. The default 4-bit quantization walker only + replaces ``nn.Linear`` modules, so these fused experts are silently skipped and + stay in full precision — the dominant contribution to the model's memory footprint + (see https://github.com/bitsandbytes-foundation/bitsandbytes/issues/1849). + + ``Experts4bit`` holds the two expert projections (``gate_up_proj`` and ``down_proj``) + in 4-bit NF4/FP4 precision. Unlike :class:`Linear4bit`, the packed weights are kept + as plain ``nn.Parameter`` buffers and the per-expert quantization statistics + (``absmax``) live on the module as ordinary buffers. This avoids bending + :class:`Params4bit`'s tensor-subclass and device-movement machinery around a 3D + stack, and it means the module serializes through the standard ``state_dict`` + mechanism with no custom save/load hooks. 
+ + The forward pass dequantizes a single expert at a time (a per-expert loop), mirroring + the reference fused-experts forward. Grouped-GEMM is intentionally left for future + work. + + This feature is experimental and may change in future releases. + + Args: + num_experts (`int`): Number of experts in the layer. + hidden_dim (`int`): Model hidden size (the ``in_features`` of ``gate_up_proj`` and + the ``out_features`` of ``down_proj``). + intermediate_dim (`int`): Expert intermediate size (the ``in_features`` of + ``down_proj``). + has_gate (`bool`, *optional*, defaults to `True`): Whether ``gate_up_proj`` packs a + gate and an up projection (SwiGLU-style). When `False`, the projection is a + plain up projection of size ``intermediate_dim``. + activation (`Callable`, *optional*): The activation applied to the gate. Defaults + to ``torch.nn.functional.silu`` (SwiGLU), matching OLMoE / Qwen3-MoE. + compute_dtype (`torch.dtype`, *optional*): The dtype expert weights are + dequantized to for the matmul. When `None`, the input's dtype is used. + quant_type (`str`, *optional*, defaults to `"nf4"`): The 4-bit data type, ``nf4`` + or ``fp4``. + blocksize (`int`, *optional*, defaults to `64`): The quantization block size. + device (*optional*): The device for the (empty) packed buffers. + + Raises: + ValueError: If ``quant_type`` is invalid, or if ``hidden_dim`` / ``intermediate_dim`` + is not divisible by ``blocksize`` (required so per-expert quantization blocks + never straddle an expert boundary). 
+ """ + + def __init__( + self, + num_experts: int, + hidden_dim: int, + intermediate_dim: int, + has_gate: bool = True, + activation: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, + compute_dtype: Optional[torch.dtype] = None, + quant_type: str = "nf4", + blocksize: int = 64, + device=None, + ): + super().__init__() + + if quant_type not in ("nf4", "fp4"): + raise ValueError(f"quant_type must be 'nf4' or 'fp4', got {quant_type!r}") + + # Each expert is quantized independently, so an expert occupies a contiguous + # `out_features * in_features` run of elements. Requiring the in_features dim to + # be a multiple of the blocksize guarantees `out_features * in_features` is too, + # so blocks tile each expert exactly and absmax reshapes cleanly to + # [num_experts, blocks_per_expert]. (gate_up in_features is hidden_dim; down_proj + # in_features is intermediate_dim.) + for name, in_features in (("hidden_dim", hidden_dim), ("intermediate_dim", intermediate_dim)): + if in_features % blocksize != 0: + raise ValueError( + f"{name} ({in_features}) must be divisible by blocksize ({blocksize}) " + "so per-expert quantization blocks align with expert boundaries" + ) + + self.num_experts = num_experts + self.hidden_dim = hidden_dim + self.intermediate_dim = intermediate_dim + self.has_gate = has_gate + self.act_fn = activation if activation is not None else F_nn.silu + self.compute_dtype = compute_dtype + self.quant_type = quant_type + self.blocksize = blocksize + + gate_up_out = 2 * intermediate_dim if has_gate else intermediate_dim + self._gate_up_shape = (gate_up_out, hidden_dim) + self._down_shape = (hidden_dim, intermediate_dim) + + gate_up_numel = gate_up_out * hidden_dim + down_numel = hidden_dim * intermediate_dim + + # Packed 4-bit weights as plain (frozen) parameters: two 4-bit values per byte. 
+ self.gate_up_proj = nn.Parameter( + torch.empty(num_experts, gate_up_numel // 2, dtype=torch.uint8, device=device), + requires_grad=False, + ) + self.down_proj = nn.Parameter( + torch.empty(num_experts, down_numel // 2, dtype=torch.uint8, device=device), + requires_grad=False, + ) + + # Per-expert quantization scales. + self.register_buffer( + "gate_up_absmax", + torch.empty(num_experts, gate_up_numel // blocksize, dtype=torch.float32, device=device), + ) + self.register_buffer( + "down_absmax", + torch.empty(num_experts, down_numel // blocksize, dtype=torch.float32, device=device), + ) + + # The 4-bit codebook is identical for every expert and fully determined by + # quant_type, so it is reconstructed at init rather than serialized. + self.register_buffer("code", F.get_4bit_type(quant_type, device=device), persistent=False) + + @classmethod + def from_float( + cls, + gate_up_proj: torch.Tensor, + down_proj: torch.Tensor, + has_gate: bool = True, + activation: Optional[Callable[[torch.Tensor], torch.Tensor]] = None, + compute_dtype: Optional[torch.dtype] = None, + quant_type: str = "nf4", + blocksize: int = 64, + ) -> "Experts4bit": + """Build an :class:`Experts4bit` by quantizing full-precision expert weights. + + Args: + gate_up_proj (`torch.Tensor`): Shape ``[num_experts, gate_up_out, hidden_dim]``, + where ``gate_up_out`` is ``2 * intermediate_dim`` when ``has_gate`` else + ``intermediate_dim``. + down_proj (`torch.Tensor`): Shape ``[num_experts, hidden_dim, intermediate_dim]``. + + Returns: + `Experts4bit`: A module holding the quantized weights on the inputs' device. 
+ """ + if gate_up_proj.dim() != 3 or down_proj.dim() != 3: + raise ValueError("gate_up_proj and down_proj must be 3D [num_experts, out, in] tensors") + + num_experts, _, hidden_dim = gate_up_proj.shape + intermediate_dim = down_proj.shape[2] + + module = cls( + num_experts, + hidden_dim, + intermediate_dim, + has_gate=has_gate, + activation=activation, + compute_dtype=compute_dtype if compute_dtype is not None else gate_up_proj.dtype, + quant_type=quant_type, + blocksize=blocksize, + device=gate_up_proj.device, + ) + + gate_up_packed, gate_up_absmax = module._quantize_stack(gate_up_proj) + down_packed, down_absmax = module._quantize_stack(down_proj) + + module.gate_up_proj = nn.Parameter(gate_up_packed, requires_grad=False) + module.down_proj = nn.Parameter(down_packed, requires_grad=False) + module.gate_up_absmax = gate_up_absmax + module.down_absmax = down_absmax + return module + + def _quantize_stack(self, weights: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + """Quantize a ``[num_experts, out, in]`` stack to packed bytes + per-expert absmax.""" + packed = [] + absmax = [] + for e in range(weights.shape[0]): + q, state = F.quantize_4bit( + weights[e].contiguous(), + blocksize=self.blocksize, + compress_statistics=False, + quant_type=self.quant_type, + ) + packed.append(q.reshape(-1)) + absmax.append(state.absmax.reshape(-1)) + return torch.stack(packed), torch.stack(absmax) + + def _dequantize_expert( + self, + packed: torch.Tensor, + absmax: torch.Tensor, + shape: tuple[int, int], + expert_idx: int, + dtype: torch.dtype, + ) -> torch.Tensor: + """Dequantize a single expert's 2D weight ``[out, in]`` for the matmul.""" + quant_state = QuantState( + absmax=absmax[expert_idx], + shape=torch.Size(shape), + code=self.code, + blocksize=self.blocksize, + quant_type=self.quant_type, + dtype=dtype, + ) + # Restore the [packed, 1] layout quantize_4bit emits (and which keeps the + # transpose back-compat shim — keyed on A.shape[0] == 1 — from firing). 
+ return F.dequantize_4bit(packed[expert_idx].reshape(-1, 1), quant_state=quant_state) + + def forward( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + ) -> torch.Tensor: + compute_dtype = self.compute_dtype if self.compute_dtype is not None else hidden_states.dtype + hidden_states = hidden_states.to(compute_dtype) + + # Accumulate in float32 for numerical stability with bf16/fp16 routing weights. + final_hidden_states = torch.zeros_like(hidden_states, dtype=torch.float32) + + with torch.no_grad(): + expert_mask = F_nn.one_hot(top_k_index, num_classes=self.num_experts) + expert_mask = expert_mask.permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False).view(-1) + + for expert_idx in expert_hit: + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + + gate_up_w = self._dequantize_expert( + self.gate_up_proj, self.gate_up_absmax, self._gate_up_shape, expert_idx, compute_dtype + ) + proj = F_nn.linear(current_state, gate_up_w) + if self.has_gate: + gate, up = proj.chunk(2, dim=-1) + current_hidden = self.act_fn(gate) * up + else: + current_hidden = self.act_fn(proj) + + down_w = self._dequantize_expert( + self.down_proj, self.down_absmax, self._down_shape, expert_idx, compute_dtype + ) + current_hidden = F_nn.linear(current_hidden, down_w) + current_hidden = current_hidden * top_k_weights[token_idx, top_k_pos, None] + final_hidden_states.index_add_(0, token_idx, current_hidden.to(final_hidden_states.dtype)) + + return final_hidden_states.to(hidden_states.dtype) diff --git a/tests/test_experts4bit.py b/tests/test_experts4bit.py new file mode 100644 index 000000000..35204a342 --- /dev/null +++ b/tests/test_experts4bit.py @@ -0,0 +1,156 @@ +import pytest +import torch + +import bitsandbytes as bnb +from bitsandbytes.nn import Experts4bit +from tests.helpers import describe_dtype, get_available_devices, id_formatter + +# 
Small but representative MoE dims. hidden_dim and intermediate_dim are both multiples +# of the default blocksize (64), as required by Experts4bit. +NUM_EXPERTS = 4 +HIDDEN_DIM = 64 +INTERMEDIATE_DIM = 128 +TOP_K = 2 +NUM_TOKENS = 12 + + +def _random_expert_weights(dtype, device, has_gate=True): + gate_up_out = 2 * INTERMEDIATE_DIM if has_gate else INTERMEDIATE_DIM + gate_up = torch.randn(NUM_EXPERTS, gate_up_out, HIDDEN_DIM, dtype=dtype, device=device) * 0.1 + down = torch.randn(NUM_EXPERTS, HIDDEN_DIM, INTERMEDIATE_DIM, dtype=dtype, device=device) * 0.1 + return gate_up, down + + +def _random_routing(device): + hidden_states = torch.randn(NUM_TOKENS, HIDDEN_DIM, device=device) + top_k_index = torch.randint(0, NUM_EXPERTS, (NUM_TOKENS, TOP_K), device=device) + top_k_weights = torch.softmax(torch.randn(NUM_TOKENS, TOP_K, device=device), dim=-1) + return hidden_states, top_k_index, top_k_weights + + +def _reference_forward(gate_up, down, hidden_states, top_k_index, top_k_weights, act_fn=torch.nn.functional.silu): + """Plain full-precision fused-experts forward (mirrors OlmoeExperts.forward).""" + compute_dtype = gate_up.dtype + hidden_states = hidden_states.to(compute_dtype) + final = torch.zeros_like(hidden_states, dtype=torch.float32) + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=gate_up.shape[0]).permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False).view(-1) + for expert_idx in expert_hit: + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + current_state = hidden_states[token_idx] + gate, up = torch.nn.functional.linear(current_state, gate_up[expert_idx]).chunk(2, dim=-1) + current = act_fn(gate) * up + current = torch.nn.functional.linear(current, down[expert_idx]) + current = current * top_k_weights[token_idx, top_k_pos, None] + final.index_add_(0, token_idx, current.to(final.dtype)) + return final.to(hidden_states.dtype) + + +@pytest.mark.parametrize("device", 
get_available_devices()) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) +@pytest.mark.parametrize("quant_type", ["nf4", "fp4"]) +def test_experts4bit_roundtrip(device, dtype, quant_type): + gate_up, down = _random_expert_weights(dtype, device) + module = Experts4bit.from_float(gate_up, down, quant_type=quant_type) + + # Packed-weight and absmax shapes/dtypes. + gate_up_out = 2 * INTERMEDIATE_DIM + assert module.gate_up_proj.dtype == torch.uint8 + assert module.gate_up_proj.shape == (NUM_EXPERTS, gate_up_out * HIDDEN_DIM // 2) + assert module.down_proj.shape == (NUM_EXPERTS, HIDDEN_DIM * INTERMEDIATE_DIM // 2) + assert module.gate_up_absmax.shape == (NUM_EXPERTS, gate_up_out * HIDDEN_DIM // module.blocksize) + assert not module.gate_up_proj.requires_grad + + # Per-expert dequantization round-trips within 4-bit tolerance. + for e in range(NUM_EXPERTS): + deq = module._dequantize_expert(module.gate_up_proj, module.gate_up_absmax, module._gate_up_shape, e, dtype) + assert deq.shape == (gate_up_out, HIDDEN_DIM) + assert deq.dtype == dtype + torch.testing.assert_close(deq.float(), gate_up[e].float(), rtol=0.15, atol=0.05) + + +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("has_gate", [True, False], ids=id_formatter("has_gate")) +def test_experts4bit_forward_matches_reference(device, has_gate): + # float32 compute so the only difference vs. the reference is float accumulation order. + gate_up, down = _random_expert_weights(torch.float32, device, has_gate=has_gate) + module = Experts4bit.from_float(gate_up, down, has_gate=has_gate, compute_dtype=torch.float32) + + hidden_states, top_k_index, top_k_weights = _random_routing(device) + + # Reference uses the exact weights the module holds internally (dequantized bytes), + # isolating forward/routing correctness from quantization error. 
+ gate_up_deq = torch.stack( + [ + module._dequantize_expert( + module.gate_up_proj, module.gate_up_absmax, module._gate_up_shape, e, torch.float32 + ) + for e in range(NUM_EXPERTS) + ] + ) + down_deq = torch.stack( + [ + module._dequantize_expert(module.down_proj, module.down_absmax, module._down_shape, e, torch.float32) + for e in range(NUM_EXPERTS) + ] + ) + + if has_gate: + ref = _reference_forward(gate_up_deq, down_deq, hidden_states, top_k_index, top_k_weights) + else: + # no-gate reference: act_fn applied to the whole projection + ref = torch.zeros_like(hidden_states, dtype=torch.float32) + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=NUM_EXPERTS).permute(2, 1, 0) + for expert_idx in torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False).view(-1): + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + cur = torch.nn.functional.silu( + torch.nn.functional.linear(hidden_states[token_idx], gate_up_deq[expert_idx]) + ) + cur = torch.nn.functional.linear(cur, down_deq[expert_idx]) + cur = cur * top_k_weights[token_idx, top_k_pos, None] + ref.index_add_(0, token_idx, cur) + ref = ref.to(hidden_states.dtype) + + out = module(hidden_states, top_k_index, top_k_weights) + assert out.shape == hidden_states.shape + torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4) + + +@pytest.mark.parametrize("device", get_available_devices()) +def test_experts4bit_state_dict_roundtrip(device): + gate_up, down = _random_expert_weights(torch.float16, device) + module = Experts4bit.from_float(gate_up, down, compute_dtype=torch.float16) + + # Default state_dict carries everything (plain Parameters + buffers — no custom hooks). 
+ sd = module.state_dict() + assert "gate_up_proj" in sd and "down_proj" in sd + assert "gate_up_absmax" in sd and "down_absmax" in sd + assert "code" not in sd # codebook is non-persistent (reconstructed at init) + + reloaded = Experts4bit(NUM_EXPERTS, HIDDEN_DIM, INTERMEDIATE_DIM, compute_dtype=torch.float16, device=device) + missing, unexpected = reloaded.load_state_dict(sd, strict=True), None + assert missing.missing_keys == [] and missing.unexpected_keys == [] + + # Bit-exact restore of packed weights + absmax. + torch.testing.assert_close(reloaded.gate_up_proj, module.gate_up_proj, rtol=0, atol=0) + torch.testing.assert_close(reloaded.down_absmax, module.down_absmax, rtol=0, atol=0) + + # Identical forward after reload. + hidden_states, top_k_index, top_k_weights = _random_routing(device) + out_a = module(hidden_states, top_k_index, top_k_weights) + out_b = reloaded(hidden_states, top_k_index, top_k_weights) + torch.testing.assert_close(out_a, out_b, rtol=0, atol=0) + + +def test_experts4bit_blocksize_validation(): + # in_features (hidden_dim / intermediate_dim) must be divisible by blocksize. 
+ with pytest.raises(ValueError, match="divisible by blocksize"): + Experts4bit(NUM_EXPERTS, hidden_dim=100, intermediate_dim=128, blocksize=64) + with pytest.raises(ValueError, match="divisible by blocksize"): + Experts4bit(NUM_EXPERTS, hidden_dim=64, intermediate_dim=100, blocksize=64) + with pytest.raises(ValueError, match="quant_type"): + Experts4bit(NUM_EXPERTS, HIDDEN_DIM, INTERMEDIATE_DIM, quant_type="int4") + + +def test_experts4bit_is_exported(): + assert bnb.nn.Experts4bit is Experts4bit From 5a0c73add706331092e373ce21d602560e49799e Mon Sep 17 00:00:00 2001 From: Jordan Anderson Date: Wed, 10 Jun 2026 12:04:05 -0500 Subject: [PATCH 2/3] Add Experts4bit reference docs --- docs/source/_toctree.yml | 2 ++ docs/source/reference/nn/experts.mdx | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 docs/source/reference/nn/experts.mdx diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index 0f46fe6b0..f47eb24de 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -59,5 +59,7 @@ title: LLM.int8() - local: reference/nn/linear4bit title: 4-bit quantizer + - local: reference/nn/experts + title: 4-bit MoE experts - local: reference/nn/embeddings title: Embedding diff --git a/docs/source/reference/nn/experts.mdx b/docs/source/reference/nn/experts.mdx new file mode 100644 index 000000000..18478b0d7 --- /dev/null +++ b/docs/source/reference/nn/experts.mdx @@ -0,0 +1,24 @@ +# 4-bit MoE experts + +Some Mixture-of-Experts (MoE) models store their expert weights as a single fused 3D parameter of shape `[num_experts, out_features, in_features]` (for example `OlmoeExperts` and `Qwen3MoeExperts` in Transformers) instead of a collection of `nn.Linear` layers. The `nn.Linear`-based 4-bit replacement path skips these fused parameters, leaving the experts — typically the bulk of the model's weights — in full precision. 
+ +`Experts4bit` stores the fused `gate_up_proj` and `down_proj` expert stacks in 4-bit (NF4 or FP4) precision with per-expert quantization statistics, and dequantizes one expert at a time during the forward pass. + +```py +from bitsandbytes.nn import Experts4bit + +# Quantize an existing fp16/bf16 fused-expert stack: +experts = Experts4bit.from_float(gate_up_proj, down_proj, quant_type="nf4") +out = experts(hidden_states, top_k_index, top_k_weights) + +# Or construct empty and load a pre-quantized checkpoint (pass the same quant_type/blocksize it was quantized with; the state_dict does not store them): +experts = Experts4bit(num_experts, hidden_dim, intermediate_dim) +experts.load_state_dict(sd) +``` + +## Experts4bit + +[[autodoc]] bitsandbytes.nn.Experts4bit + - __init__ + - from_float + - forward From 7a2b9fdec7d1b881ae2ad14009fe7750a1008988 Mon Sep 17 00:00:00 2001 From: pjordanandrsn Date: Tue, 30 Jun 2026 11:18:14 +0000 Subject: [PATCH 3/3] test: add backward + QLoRA-training coverage and demo for Experts4bit Prove Experts4bit works as a frozen 4-bit QLoRA base. New tests in tests/test_experts4bit.py cover the autograd contract: gradients reach the input activations, the frozen packed weights never receive a gradient, and the backward matches a full-precision reference forward. Add examples/experts4bit_qlora_demo.py with a small per-expert ExpertsLoRA wrapper over a frozen base, plus a test that a real optimizer step reduces loss while the 4-bit base stays bit-identical. The wrapper is a reference pattern (PEFT/Unsloth territory), intentionally not part of the bitsandbytes public API.
Co-Authored-By: Claude Opus 4.8 --- examples/experts4bit_qlora_demo.py | 154 +++++++++++++++++++++++++++++ tests/test_experts4bit.py | 138 ++++++++++++++++++++++++++ 2 files changed, 292 insertions(+) create mode 100644 examples/experts4bit_qlora_demo.py diff --git a/examples/experts4bit_qlora_demo.py b/examples/experts4bit_qlora_demo.py new file mode 100644 index 000000000..e5819c38e --- /dev/null +++ b/examples/experts4bit_qlora_demo.py @@ -0,0 +1,154 @@ +"""QLoRA-style training of fused MoE experts on a frozen ``Experts4bit`` base. + +This is a *reference pattern*, intentionally **not** part of the bitsandbytes public API. It +shows that the ``Experts4bit`` 4-bit storage primitive can serve as the frozen base of a +QLoRA-style fine-tune of fused Mixture-of-Experts weights: the 4-bit expert weights stay +frozen, and small per-expert low-rank (LoRA) adapters are the only trainable parameters. + +The adapter wiring shown here is the kind of thing that would ultimately live in PEFT / +Unsloth rather than in bitsandbytes itself — the point of this file is to demonstrate that +the base primitive is already differentiable and trainable as a frozen base today. + +Run: + python examples/experts4bit_qlora_demo.py +""" + +from __future__ import annotations + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from bitsandbytes.nn import Experts4bit + + +class ExpertsLoRA(nn.Module): + """Per-expert LoRA adapters over a frozen :class:`Experts4bit` base. + + For each expert ``e``, the two frozen 4-bit projections are augmented with a trainable + low-rank term ``scaling * (x @ A[e].T) @ B[e].T``: + + * ``gate_up``: ``A[e]`` is ``[r, hidden]``, ``B[e]`` is ``[gate_up_out, r]`` + * ``down``: ``A[e]`` is ``[r, intermediate]``, ``B[e]`` is ``[hidden, r]`` + + ``B`` is initialised to zero, so the adapted module is identical to the frozen base at + step 0 and only departs from it as the adapters train (standard LoRA initialisation). 
+ """ + + def __init__(self, base: Experts4bit, r: int = 8, alpha: int = 16, dtype: torch.dtype = torch.float32): + super().__init__() + self.base = base + for p in self.base.parameters(): + p.requires_grad_(False) + + self.r = r + self.scaling = alpha / r + + num_experts = base.num_experts + gate_up_out, hidden = base._gate_up_shape # [2*intermediate (or intermediate), hidden] + _, intermediate = base._down_shape # [hidden, intermediate] + + self.gate_up_lora_A = nn.Parameter(torch.empty(num_experts, r, hidden, dtype=dtype)) + self.gate_up_lora_B = nn.Parameter(torch.zeros(num_experts, gate_up_out, r, dtype=dtype)) + self.down_lora_A = nn.Parameter(torch.empty(num_experts, r, intermediate, dtype=dtype)) + self.down_lora_B = nn.Parameter(torch.zeros(num_experts, hidden, r, dtype=dtype)) + + # A ~ small random, B = 0 => the initial LoRA delta is exactly zero. + nn.init.normal_(self.gate_up_lora_A, std=1.0 / r) + nn.init.normal_(self.down_lora_A, std=1.0 / r) + + def _lora(self, x: torch.Tensor, A: torch.Tensor, B: torch.Tensor) -> torch.Tensor: + # x: [n, in]; A: [r, in]; B: [out, r] -> [n, out] + return self.scaling * F.linear(F.linear(x, A), B) + + def forward( + self, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, + ) -> torch.Tensor: + base = self.base + compute_dtype = base.compute_dtype if base.compute_dtype is not None else hidden_states.dtype + hidden_states = hidden_states.to(compute_dtype) + + final_hidden_states = torch.zeros_like(hidden_states, dtype=torch.float32) + + with torch.no_grad(): + expert_mask = F.one_hot(top_k_index, num_classes=base.num_experts).permute(2, 1, 0) + expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero(as_tuple=False).view(-1) + + for expert_idx in expert_hit: + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) + x = hidden_states[token_idx] + + # Frozen 4-bit base projection + trainable low-rank delta. 
+ gate_up_w = base._dequantize_expert( + base.gate_up_proj, base.gate_up_absmax, base._gate_up_shape, expert_idx, compute_dtype + ) + proj = F.linear(x, gate_up_w) + self._lora( + x, self.gate_up_lora_A[expert_idx], self.gate_up_lora_B[expert_idx] + ) + + if base.has_gate: + gate, up = proj.chunk(2, dim=-1) + current_hidden = base.act_fn(gate) * up + else: + current_hidden = base.act_fn(proj) + + down_w = base._dequantize_expert( + base.down_proj, base.down_absmax, base._down_shape, expert_idx, compute_dtype + ) + current_hidden = F.linear(current_hidden, down_w) + self._lora( + current_hidden, self.down_lora_A[expert_idx], self.down_lora_B[expert_idx] + ) + + current_hidden = current_hidden * top_k_weights[token_idx, top_k_pos, None] + final_hidden_states.index_add_(0, token_idx, current_hidden.to(final_hidden_states.dtype)) + + return final_hidden_states.to(hidden_states.dtype) + + +def main() -> None: + torch.manual_seed(0) + + num_experts, hidden, intermediate = 8, 128, 256 + num_tokens, top_k = 64, 2 + + # A full-precision fused-expert stack (the shape transformers v5 stores MoE experts in). + gate_up = torch.randn(num_experts, 2 * intermediate, hidden) * 0.1 + down = torch.randn(num_experts, hidden, intermediate) * 0.1 + + # Freeze it in 4-bit, then attach trainable LoRA adapters. 
+ base = Experts4bit.from_float(gate_up, down, quant_type="nf4", compute_dtype=torch.float32) + model = ExpertsLoRA(base, r=8, alpha=16) + + trainable = [p for p in model.parameters() if p.requires_grad] + n_train = sum(p.numel() for p in trainable) + n_base_bytes = base.gate_up_proj.numel() + base.down_proj.numel() + print(f"trainable LoRA params: {n_train:,} frozen packed base bytes: {n_base_bytes:,}") + + hidden_states = torch.randn(num_tokens, hidden) + top_k_index = torch.randint(0, num_experts, (num_tokens, top_k)) + top_k_weights = torch.softmax(torch.randn(num_tokens, top_k), dim=-1) + target = torch.randn(num_tokens, hidden) + + gate_up_before = base.gate_up_proj.clone() + + optimizer = torch.optim.Adam(trainable, lr=1e-2) + print("\nstep loss") + for step in range(50): + optimizer.zero_grad() + out = model(hidden_states, top_k_index, top_k_weights) + loss = F.mse_loss(out, target) + loss.backward() + assert base.gate_up_proj.grad is None, "frozen base must never receive a gradient" + optimizer.step() + if step % 10 == 0 or step == 49: + print(f"{step:4d} {loss.item():.5f}") + + assert torch.equal(base.gate_up_proj, gate_up_before), "frozen base bytes must be unchanged" + print("\nbase packed weights unchanged after training:", torch.equal(base.gate_up_proj, gate_up_before)) + + +if __name__ == "__main__": + main() diff --git a/tests/test_experts4bit.py b/tests/test_experts4bit.py index 35204a342..cf3c640b9 100644 --- a/tests/test_experts4bit.py +++ b/tests/test_experts4bit.py @@ -154,3 +154,141 @@ def test_experts4bit_blocksize_validation(): def test_experts4bit_is_exported(): assert bnb.nn.Experts4bit is Experts4bit + + +# --- Backward / autograd --------------------------------------------------------------- +# Experts4bit is a *frozen* 4-bit base: the packed weights are requires_grad=False, so they +# never receive gradients, but the per-expert dequant + linear + index_add_ forward is fully +# differentiable w.r.t. the input activations. 
That makes it usable as the frozen base of a +# QLoRA-style setup (gradients flow to adapters/earlier layers, not to the quantized weights). +# These tests lock that contract in. + + +@pytest.mark.parametrize("device", get_available_devices()) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype) +def test_experts4bit_backward_flows_to_input(device, dtype): + gate_up, down = _random_expert_weights(dtype, device) + module = Experts4bit.from_float(gate_up, down, compute_dtype=dtype) + + hidden_states, top_k_index, top_k_weights = _random_routing(device) + hidden_states = hidden_states.to(dtype).detach().requires_grad_(True) + + out = module(hidden_states, top_k_index, top_k_weights) + out.float().sum().backward() + + # Gradient reaches the input activations, is finite, and is nonzero (every token is routed + # to TOP_K experts here, so every row contributes). + assert hidden_states.grad is not None + assert torch.isfinite(hidden_states.grad).all() + assert hidden_states.grad.float().abs().sum() > 0 + + +@pytest.mark.parametrize("device", get_available_devices()) +def test_experts4bit_base_weights_stay_frozen(device): + gate_up, down = _random_expert_weights(torch.float32, device) + module = Experts4bit.from_float(gate_up, down, compute_dtype=torch.float32) + + # Packed weights are frozen by construction ... + assert module.gate_up_proj.requires_grad is False + assert module.down_proj.requires_grad is False + + hidden_states, top_k_index, top_k_weights = _random_routing(device) + hidden_states = hidden_states.requires_grad_(True) + module(hidden_states, top_k_index, top_k_weights).sum().backward() + + # ... and a backward pass leaves no gradient on them (so an optimizer can never nudge the + # quantized base, and the absmax buffers are not trainable either). 
+ assert module.gate_up_proj.grad is None + assert module.down_proj.grad is None + + +@pytest.mark.parametrize("device", get_available_devices()) +def test_experts4bit_backward_matches_reference(device): + # float32 throughout: the module's autograd path must match a plain full-precision forward + # built from the *same* dequantized weights, isolating gradient correctness from quant error. + gate_up, down = _random_expert_weights(torch.float32, device) + module = Experts4bit.from_float(gate_up, down, compute_dtype=torch.float32) + + gate_up_deq = torch.stack( + [ + module._dequantize_expert( + module.gate_up_proj, module.gate_up_absmax, module._gate_up_shape, e, torch.float32 + ) + for e in range(NUM_EXPERTS) + ] + ) + down_deq = torch.stack( + [ + module._dequantize_expert(module.down_proj, module.down_absmax, module._down_shape, e, torch.float32) + for e in range(NUM_EXPERTS) + ] + ) + + hidden_states, top_k_index, top_k_weights = _random_routing(device) + hs_mod = hidden_states.detach().clone().requires_grad_(True) + hs_ref = hidden_states.detach().clone().requires_grad_(True) + + out_mod = module(hs_mod, top_k_index, top_k_weights) + out_ref = _reference_forward(gate_up_deq, down_deq, hs_ref, top_k_index, top_k_weights) + + out_mod.sum().backward() + out_ref.sum().backward() + + torch.testing.assert_close(out_mod, out_ref, rtol=1e-4, atol=1e-4) + torch.testing.assert_close(hs_mod.grad, hs_ref.grad, rtol=1e-4, atol=1e-4) + + +def _load_experts_lora(): + """Load the ExpertsLoRA reference wrapper from examples/ (kept out of the bnb API).""" + import importlib.util + import os + + path = os.path.join(os.path.dirname(__file__), "..", "examples", "experts4bit_qlora_demo.py") + spec = importlib.util.spec_from_file_location("experts4bit_qlora_demo", path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module.ExpertsLoRA + + +def test_experts4bit_lora_training_reduces_loss(): + # End-to-end QLoRA-style step: a frozen 4-bit 
Experts4bit base + trainable per-expert LoRA. + # Proves the primitive supports training today — only the adapters move, the base stays put. + torch.manual_seed(0) + experts_lora = _load_experts_lora() + + gate_up, down = _random_expert_weights(torch.float32, "cpu") + base = Experts4bit.from_float(gate_up, down, compute_dtype=torch.float32) + model = experts_lora(base, r=4, alpha=8) + + # Only LoRA adapters are trainable; the 4-bit base is frozen. + trainable_names = [name for name, p in model.named_parameters() if p.requires_grad] + assert trainable_names and all("lora" in name for name in trainable_names) + + gate_up_before = base.gate_up_proj.clone() + down_before = base.down_proj.clone() + + hidden_states, top_k_index, top_k_weights = _random_routing("cpu") + target = torch.randn_like(hidden_states) + + # Standard LoRA init (B=0) => the adapted forward equals the frozen base forward at step 0. + torch.testing.assert_close( + model(hidden_states, top_k_index, top_k_weights), + base(hidden_states, top_k_index, top_k_weights), + rtol=1e-5, + atol=1e-5, + ) + + optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=1e-2) + losses = [] + for _ in range(30): + optimizer.zero_grad() + loss = torch.nn.functional.mse_loss(model(hidden_states, top_k_index, top_k_weights), target) + loss.backward() + assert base.gate_up_proj.grad is None and base.down_proj.grad is None + optimizer.step() + losses.append(loss.item()) + + assert losses[-1] < losses[0] # training reduces loss + # The frozen 4-bit base is bit-identical before and after training. + assert torch.equal(base.gate_up_proj, gate_up_before) + assert torch.equal(base.down_proj, down_before)