Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,10 @@ elseif(BUILD_HIP)

list(APPEND SRC_FILES ${GPU_FILES})

# 4-bit GEMM: build the SIMT kernel + C dispatch on ROCm.
set_source_files_properties(csrc/gemm_4bit_simt.cu csrc/gemm_4bit.cu PROPERTIES LANGUAGE HIP)
list(APPEND SRC_FILES csrc/gemm_4bit_simt.cu csrc/gemm_4bit.cu)

string(APPEND BNB_OUTPUT_NAME "_rocm")

# get hip version
Expand Down
283 changes: 152 additions & 131 deletions bitsandbytes/backends/cuda/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@

from ..._ops import register_kernel
from ...cextension import lib
from ..default.ops import _gemm_4bit_default_impl
from ..utils import _get_4bit_code


def _setup_ctypes(names, argtypes, restype=None):
Expand Down Expand Up @@ -583,7 +581,7 @@ def _gemv_4bit_impl(


@functools.cache
def _gemm_4bit_use_custom(device_index, dtype, M, N, K):
def _gemm_4bit_use_custom_cuda(device_index, dtype, M, N, K):
"""Custom kernel vs dequant+F.linear heuristic for M in [5, 1536].

Per-arch notes (bf16/fp16, M >= 8, large weight):
Expand All @@ -595,6 +593,9 @@ def _gemm_4bit_use_custom(device_index, dtype, M, N, K):
sm100 (B200/B300, HBM3e): exits early at top of function.
sm120 (RTX 5000, GDDR7): dedicated block; medium-N tiers differ from sm89.
"""
if M <= _GEMM_4BIT_CUSTOM_FLOOR_M:
return True

num_sms, major, minor = _gpu_dispatch_props(device_index)
n_blocks = (N + 63) // 64

Expand Down Expand Up @@ -800,140 +801,156 @@ def _gemm_4bit_use_custom(device_index, dtype, M, N, K):
return M <= (16 if (tall_k_2xn or n_blocks < 48) else 8)


if torch.version.hip is None:
@functools.cache
def _gemm_4bit_use_custom_rocm(device_index, dtype, M, N, K):
Comment thread
sstamenk marked this conversation as resolved.
"""
Fused SIMT kernel vs dequant+F.linear heuristic for ROCm.

@register_kernel("bitsandbytes::gemm_4bit", "cuda")
def _(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
bias: Optional[torch.Tensor] = None,
absmax_8bit: Optional[torch.Tensor] = None,
absmax_code: Optional[torch.Tensor] = None,
absmax_offset: Optional[torch.Tensor] = None,
) -> torch.Tensor:
K = A.shape[-1]
M = A.numel() // K
N = shapeB[0]

# M>1536: dequant+F.linear wins (dequant savings negligible at very large batch).
# M<=4: always custom (custom kernel wins universally at small batch).
# M in [5, 1536]: shape/arch-dependent; cached per (device, dtype, M, N, K).
if M > 1536:
use_custom = False
elif K % blocksize != 0:
warn(
f"inner dimension ({K}) is not aligned for fast kernel "
f"with blocksize={blocksize}, falling back to slower implementation.",
UserWarning,
)
use_custom = False
else:
use_custom = M <= 4 or _gemm_4bit_use_custom(A.device.index, A.dtype, M, N, K)

if not use_custom:
if absmax_8bit is not None:
absmax_dq = torch.empty_like(absmax_8bit, dtype=torch.float32)
_dequantize_blockwise_impl(absmax_8bit, absmax, absmax_code, 256, torch.float32, out=absmax_dq)
absmax = absmax_dq + absmax_offset
B_dq = torch.empty(shapeB, dtype=A.dtype, device=A.device)
_dequantize_4bit_impl(B, absmax, blocksize, quant_type, A.dtype, out=B_dq)
return torch.nn.functional.linear(A, B_dq, bias)

if K != shapeB[1]:
raise RuntimeError(f"A inner dim ({K}) does not match weight ({shapeB[1]})")
if absmax.dtype != torch.float32:
raise RuntimeError(f"absmax must be float32, got {absmax.dtype}")
if bias is not None:
if bias.ndim != 1:
raise RuntimeError(f"bias must be 1D, got {bias.ndim}D")
if bias.dtype != A.dtype:
raise RuntimeError(f"bias dtype ({bias.dtype}) must match A dtype ({A.dtype})")

quant_type_int = 1 if quant_type == "fp4" else 2

out = torch.empty((*A.shape[:-1], N), dtype=A.dtype, device=A.device)
stream = _get_raw_stream(A.device.index)

if A.dtype == torch.bfloat16:
fn = lib.cgemm_4bit_bf16
elif A.dtype == torch.float16:
fn = lib.cgemm_4bit_fp16
elif A.dtype == torch.float32:
fn = lib.cgemm_4bit_fp32
else:
raise RuntimeError(f"unsupported dtype {A.dtype}")

# Offset is expected to be a float32 tensor.
absmax_offset_f32 = absmax_offset.to(dtype=torch.float32) if absmax_offset is not None else None

with _cuda_device_of(A):
fn(
A.data_ptr(),
B.data_ptr(),
absmax.data_ptr(),
absmax_8bit.data_ptr() if absmax_8bit is not None else None,
absmax_code.data_ptr() if absmax_code is not None else None,
absmax_offset_f32.data_ptr() if absmax_offset_f32 is not None else None,
out.data_ptr(),
bias.data_ptr() if bias is not None else None,
M,
N,
K,
blocksize,
quant_type_int,
stream,
)
RDNA3/RDNA4 calibration keeps the SIMT kernel through ~M=8.
CDNA/gfx9 is calibrated on MI308X (gfx942): bf16/fp16 win through M<=4
after the SIMT math-path tuning, while fp32 only has a broad win through M<=2.

TODO: revisit once WMMA/MFMA kernels land.
"""
if M <= _GEMM_4BIT_CUSTOM_FLOOR_M and dtype != torch.float32:
return True

return out
arch = _rocm_gfx_arch(device_index)
if arch.startswith("gfx11") or arch.startswith("gfx12"): # RDNA3 / RDNA4
return M <= 8
if arch.startswith("gfx9"): # CDNA / MI-series
return M <= (2 if dtype == torch.float32 else 4)
return M <= 4 # unknown ROCm arch: conservative tiny-batch floor


def _rocm_gfx_arch(device_index):
"""gfx arch string (e.g. 'gfx1100') for a ROCm device, feature flags stripped."""
name = getattr(torch.cuda.get_device_properties(device_index), "gcnArchName", "") or ""
return name.split(":")[0]


def _gemm_4bit_kernel_impl(
A, B, shapeB, absmax, blocksize, quant_type, bias=None, absmax_8bit=None, absmax_code=None, absmax_offset=None
):
"""Invoke the fused cgemm_4bit_* kernel (shared by the CUDA and ROCm dispatch; the
C dispatch in gemm_4bit.cu picks SIMT vs MMA per arch/shape). A is made contiguous
because the kernel reads it as row-major (stride K)."""
K = A.shape[-1]
M = A.numel() // K
N = shapeB[0]

if K != shapeB[1]:
raise RuntimeError(f"A inner dim ({K}) does not match weight ({shapeB[1]})")
if absmax.dtype != torch.float32:
raise RuntimeError(f"absmax must be float32, got {absmax.dtype}")
if bias is not None:
if bias.ndim != 1:
raise RuntimeError(f"bias must be 1D, got {bias.ndim}D")
if bias.dtype != A.dtype:
raise RuntimeError(f"bias dtype ({bias.dtype}) must match A dtype ({A.dtype})")

A = A.contiguous()
quant_type_int = 1 if quant_type == "fp4" else 2
out = torch.empty((*A.shape[:-1], N), dtype=A.dtype, device=A.device)
stream = _get_raw_stream(A.device.index)

if A.dtype == torch.bfloat16:
fn = lib.cgemm_4bit_bf16
elif A.dtype == torch.float16:
fn = lib.cgemm_4bit_fp16
elif A.dtype == torch.float32:
fn = lib.cgemm_4bit_fp32
else:
raise RuntimeError(f"unsupported dtype {A.dtype}")

# Offset is expected to be a float32 tensor.
absmax_offset_f32 = absmax_offset.to(dtype=torch.float32) if absmax_offset is not None else None

with _cuda_device_of(A):
fn(
A.data_ptr(),
B.data_ptr(),
absmax.data_ptr(),
absmax_8bit.data_ptr() if absmax_8bit is not None else None,
absmax_code.data_ptr() if absmax_code is not None else None,
absmax_offset_f32.data_ptr() if absmax_offset_f32 is not None else None,
out.data_ptr(),
bias.data_ptr() if bias is not None else None,
M,
N,
K,
blocksize,
quant_type_int,
stream,
)

return out


def _dequant_linear_fallback(
A, B, shapeB, absmax, blocksize, quant_type, bias=None, absmax_8bit=None, absmax_code=None, absmax_offset=None
):
"""Unfused fallback shared by CUDA and ROCm: reconstruct the (optionally nested)
absmax, dequantize the 4-bit weight via the backend dequant impls (reusing
preallocated buffers), then F.linear."""
if absmax_8bit is not None:
absmax_dq = torch.empty_like(absmax_8bit, dtype=torch.float32)
_dequantize_blockwise_impl(absmax_8bit, absmax, absmax_code, 256, torch.float32, out=absmax_dq)
absmax = absmax_dq + absmax_offset
B_dq = torch.empty(shapeB, dtype=A.dtype, device=A.device)
_dequantize_4bit_impl(B, absmax, blocksize, quant_type, A.dtype, out=B_dq)
return torch.nn.functional.linear(A, B_dq, bias)


# Unified CUDA/ROCm dispatch for bitsandbytes::gemm_4bit. The choice *among* custom
# kernels (CUDA SIMT vs MMA; ROCm SIMT) is made in the C dispatch (csrc/gemm_4bit.cu).
_GEMM_4BIT_CUSTOM_FLOOR_M = 4
if torch.version.hip is None:
_gemm_4bit_use_custom_fn = _gemm_4bit_use_custom_cuda
# NVIDIA: dequant+F.linear wins past M=1536 (dequant savings negligible at very
# large batch).
_gemm_4bit_custom_max_m = 1536
else:
_gemm_4bit_use_custom_fn = _gemm_4bit_use_custom_rocm
# ROCm: the custom path is SIMT-only today; the per-arch heuristic above owns
# RDNA/CDNA thresholds. Keep a hard upper cap while WMMA/MFMA paths are absent.
_gemm_4bit_custom_max_m = 256

@register_kernel("bitsandbytes::gemm_4bit", "cuda")
def _(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
bias: Optional[torch.Tensor] = None,
absmax_8bit: Optional[torch.Tensor] = None,
absmax_code: Optional[torch.Tensor] = None,
absmax_offset: Optional[torch.Tensor] = None,
) -> torch.Tensor:
K = A.shape[-1]
M = A.numel() // K
N = shapeB[0]

if M == 1:
if K % blocksize == 0:
if absmax_8bit is not None:
absmax = (
torch.ops.bitsandbytes.dequantize_blockwise.default(
absmax_8bit, absmax, absmax_code, 256, torch.float32
)
+ absmax_offset
)

code = _get_4bit_code(quant_type, A.device)
out = torch.empty((*A.shape[:-1], N), dtype=A.dtype, device=A.device)
_gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)

if bias is not None:
out = out + bias
return out

warn(
f"inner dimension ({K}) is not aligned for fast kernel "
f"with blocksize={blocksize}, falling back to slower implementation.",
UserWarning,
)

return _gemm_4bit_default_impl(
@register_kernel("bitsandbytes::gemm_4bit", "cuda")
def _(
A: torch.Tensor,
B: torch.Tensor,
shapeB: Sequence[int],
absmax: torch.Tensor,
blocksize: int,
quant_type: str,
bias: Optional[torch.Tensor] = None,
absmax_8bit: Optional[torch.Tensor] = None,
absmax_code: Optional[torch.Tensor] = None,
absmax_offset: Optional[torch.Tensor] = None,
) -> torch.Tensor:
K = A.shape[-1]
M = A.numel() // K
N = shapeB[0]

# The backend-specific heuristic owns tiny-M floors and per-arch thresholds.
# Past custom_max_m (or for blocksize-misaligned K), use the dequant+F.linear
# fallback.
if M > _gemm_4bit_custom_max_m:
use_custom = False
elif K % blocksize != 0:
warn(
f"inner dimension ({K}) is not aligned for fast kernel "
f"with blocksize={blocksize}, falling back to slower implementation.",
UserWarning,
)
use_custom = False
else:
use_custom = _gemm_4bit_use_custom_fn(A.device.index, A.dtype, M, N, K)

if not use_custom:
return _dequant_linear_fallback(
A,
B,
shapeB,
Expand All @@ -946,6 +963,10 @@ def _(
absmax_offset=absmax_offset,
)

return _gemm_4bit_kernel_impl(
A, B, shapeB, absmax, blocksize, quant_type, bias, absmax_8bit, absmax_code, absmax_offset
)


"""C FUNCTIONS FOR OPTIMIZERS"""
str2optimizer32bit = {
Expand Down
13 changes: 12 additions & 1 deletion csrc/compat.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

#if BNB_HIP

#include <hip/hip_bf16.h>
#include <hip/hip_bfloat16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_math_constants.h>
Expand Down Expand Up @@ -89,12 +90,22 @@ using bnb_error_t = cudaError_t;
} while (0)
#endif

// Full-warp mask for __shfl_*_sync. ROCm 7+ statically requires a 64-bit mask
// type; CUDA uses the conventional 32-bit unsigned mask.
#if BNB_HIP
#define BNB_FULL_WARP_MASK 0xffffffffull
#else
#define BNB_FULL_WARP_MASK 0xffffffffu
#endif

// BFloat16 type alias

#if BNB_HIP
using bnb_bfloat16 = hip_bfloat16;
using bnb_bfloat16 = __hip_bfloat16;
using bnb_bfloat162 = __hip_bfloat162;
#else
using bnb_bfloat16 = __nv_bfloat16;
using bnb_bfloat162 = __nv_bfloat162;
#endif

// Data type enum aliases for BLAS libraries
Expand Down
Loading
Loading