From 94a2786bd9562aafe51d13f876965d3c90a82862 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 23 Apr 2026 15:33:34 +0200 Subject: [PATCH 01/22] init --- src/transformers/integrations/deepgemm.py | 343 ++++++++++++++++++ .../integrations/finegrained_fp8.py | 230 +----------- src/transformers/integrations/moe.py | 2 + 3 files changed, 348 insertions(+), 227 deletions(-) create mode 100644 src/transformers/integrations/deepgemm.py diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py new file mode 100644 index 000000000000..7a8fb0786446 --- /dev/null +++ b/src/transformers/integrations/deepgemm.py @@ -0,0 +1,343 @@ +# Copyright 2026 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""DeepGEMM integration: fused grouped GEMM kernels from `kernels-community/deep-gemm`. + +Provides: +- `fp8_deepgemm_matmul`: FP8 dense matmul used as a fast path inside the finegrained-fp8 Linear. +- `fp8_deepgemm_experts_forward`: FP8 M-grouped experts forward, registered as "deepgemm" in the FP8 ExpertsInterface. +- `bf16_deepgemm_experts_forward`: BF16 M-grouped experts forward, registered as "deepgemm" in the ExpertsInterface. + +Requirements: CUDA, Hopper (SM90+), CUDA runtime >= 12.3, `kernels`. +""" + +import functools + +import torch + +from ..utils import logging +from ..utils.import_utils import get_cuda_runtime_version, resolve_internal_import +from .hub_kernels import lazy_load_kernel + + +logger = logging.get_logger(__name__) + +# DeepGEMM requires M-dimension alignment to 128 for TMA-based contiguous grouped GEMM. +# TMA is an H100 hardware addition that allows applications to asynchronously and +# bi-directionally transfer 1D-5D tensors between GPU global and shared memory. +_DEEPGEMM_M_ALIGNMENT = 128 + + +@functools.cache +def _load_deepgemm_kernel(): + """ + Load deep-gemm once and return its required symbols. + + Raises: + ImportError if CUDA/hardware requirements are not met, or the kernel or + required symbols are not found. + + Returns: + Tuple of (fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous, + m_grouped_bf16_gemm_nt_contiguous, per_token_cast_to_fp8) from the deep-gemm kernel. + """ + if not torch.cuda.is_available(): + raise ImportError( + "deep-gemm kernel requires CUDA, but CUDA is not available. Use a different `experts_implementation`." + ) + + # deep-gemm requires Hopper (SM90) or newer for FP8 WGMMA instructions + major = torch.cuda.get_device_capability()[0] + if major < 9: + raise ImportError( + f"deep-gemm requires a Hopper (SM90+) or newer GPU, but the current device " + f"has compute capability {major}.x. Use a different `experts_implementation`." + ) + + # deep-gemm requires CUDA runtime >= 12.3 + cuda_major, cuda_minor = get_cuda_runtime_version() + if cuda_major < 12 or (cuda_major == 12 and cuda_minor < 3): + raise ImportError( + f"deep-gemm requires CUDA runtime 12.3+, but found {cuda_major}.{cuda_minor}. 
" + "Please upgrade your CUDA toolkit or use a different `experts_implementation`." + ) + + kernel = lazy_load_kernel("deep-gemm") + if kernel is None: + raise ImportError( + "deep-gemm kernel not found. Make sure you have the `kernels` package installed (`pip install -U kernels`)." + ) + + fp8_gemm_nt = getattr(kernel, "fp8_gemm_nt", None) + m_grouped_fp8_gemm_nt_contiguous = getattr(kernel, "m_grouped_fp8_gemm_nt_contiguous", None) + m_grouped_bf16_gemm_nt_contiguous = getattr(kernel, "m_grouped_bf16_gemm_nt_contiguous", None) + per_token_cast_to_fp8 = resolve_internal_import(kernel, chained_path="utils.per_token_cast_to_fp8") + + missing = [ + name + for name, attr in [ + ("fp8_gemm_nt", fp8_gemm_nt), + ("m_grouped_fp8_gemm_nt_contiguous", m_grouped_fp8_gemm_nt_contiguous), + ("m_grouped_bf16_gemm_nt_contiguous", m_grouped_bf16_gemm_nt_contiguous), + ("utils.per_token_cast_to_fp8", per_token_cast_to_fp8), + ] + if attr is None + ] + if missing: + raise ImportError( + f"deep-gemm kernel is missing required symbols: {', '.join(missing)}. " + "Please update the `kernels` package (`pip install -U kernels`)." + ) + + return fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous, m_grouped_bf16_gemm_nt_contiguous, per_token_cast_to_fp8 + + +def fp8_deepgemm_matmul( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + FP8 dense matmul via deep-gemm's `fp8_gemm_nt`. Block-wise 128x128 scales expected. + + Args: + A: (M, K) float8_e4m3fn — quantized activations + B: (N, K) float8_e4m3fn — quantized weights + As: (M, K//128) float32 — per-block activation scales + Bs: (N//128, K//128) float32 — per-block weight scales + output_dtype: desired output dtype. + """ + fp8_gemm_nt, _, _, _ = _load_deepgemm_kernel() + A_2d = A.view(-1, A.shape[-1]) + As_2d = As.view(-1, As.shape[-1]) + output = torch.empty(A_2d.shape[0], B.shape[0], device=A.device, dtype=output_dtype) + fp8_gemm_nt((A_2d, As_2d.float()), (B, Bs.float()), output) + return output.view(A.shape[:-1] + (B.shape[0],)) + + +def _build_deepgemm_contiguous_layout(expert_ids_sorted: torch.Tensor, num_experts: int, alignment: int) -> tuple: + """Build a TMA-aligned contiguous layout for deep-gemm's grouped GEMM. + + deep-gemm requires M-dimension alignment per expert for TMA. This computes + the mapping from sorted token positions to padded row positions, and the + layout tensor that deep-gemm uses to identify expert boundaries. + + Returns: + sorted_to_padded: (num_tokens,) index map from sorted position to padded row + grouped_layout: expert layout tensor (format depends on GPU architecture) + total_padded_rows: total number of rows including alignment padding + """ + device = expert_ids_sorted.device + num_tokens = expert_ids_sorted.size(0) + tokens_per_expert = torch.histc(expert_ids_sorted.int(), bins=num_experts, min=0, max=num_experts - 1).long() + aligned_tokens_per_expert = ((tokens_per_expert + alignment - 1) // alignment) * alignment + # Upper bound avoids GPU->CPU sync; padding rows are skipped by deep-gemm. 
+ total_padded_rows = num_tokens + min(num_tokens, num_experts) * (alignment - 1) + + padding_per_expert = aligned_tokens_per_expert - tokens_per_expert + cumulative_padding = padding_per_expert.cumsum(0) - padding_per_expert + sorted_to_padded = torch.arange(num_tokens, device=device) + cumulative_padding[expert_ids_sorted] + + if torch.cuda.get_device_capability(device)[0] >= 10: # Blackwell (SM100+) + grouped_layout = tokens_per_expert.cumsum(0).int() + else: + # Hopper: per-row expert id, -1 for padding rows + grouped_layout = torch.full((total_padded_rows,), -1, device=device, dtype=torch.int32) + grouped_layout[sorted_to_padded] = expert_ids_sorted.int() + + return sorted_to_padded, grouped_layout, total_padded_rows + + +def _pad_for_deepgemm(x: torch.Tensor, sorted_to_padded: torch.Tensor, total_padded_rows: int) -> torch.Tensor: + """Pad a sorted tensor into the TMA-aligned contiguous layout.""" + padded = torch.zeros(total_padded_rows, *x.shape[1:], device=x.device, dtype=x.dtype) + padded[sorted_to_padded] = x + return padded + + +def _unpad_from_deepgemm_contiguous_layout(x_padded: torch.Tensor, sorted_to_padded: torch.Tensor) -> torch.Tensor: + """Remove padding rows from the TMA-aligned contiguous layout.""" + return x_padded[sorted_to_padded] + + +def fp8_deepgemm_experts_forward( + self: torch.nn.Module, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, +) -> torch.Tensor: + if self.activation_scheme == "static": + raise NotImplementedError( + "deepgemm experts dispatch does not support activation_scheme='static'. " + "Use the default eager dispatch or switch to activation_scheme='dynamic'." + ) + if self.block_size is None: + raise ValueError( + "deep-gemm requires block-wise quantization (block_size=[128, 128]), " + "but got per-tensor quantization (block_size=None)." 
+ ) + if self.block_size[0] != 128 or self.block_size[1] != 128: + raise ValueError(f"deep-gemm requires block_size=(128, 128), got {self.block_size}") + + _, m_grouped_fp8_gemm_nt_contiguous, _, per_token_cast_to_fp8 = _load_deepgemm_kernel() + + device = hidden_states.device + num_top_k = top_k_index.size(-1) + num_tokens = hidden_states.size(0) + hidden_dim = hidden_states.size(-1) + + # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) + token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,) + sample_weights = top_k_weights.reshape(-1) # (S,) + expert_ids = top_k_index.reshape(-1) # (S,) + + # Sort by expert for grouped processing + perm = torch.argsort(expert_ids) + inv_perm = torch.empty_like(perm) + inv_perm[perm] = torch.arange(perm.size(0), device=device) + + expert_ids_g = expert_ids[perm] + sample_weights_g = sample_weights[perm] + selected_hidden_states_g = hidden_states[token_idx[perm]] + + sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( + expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT + ) + use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 + + # --- Up projection per expert (deep-gemm grouped contiguous) --- + w_up = self.gate_up_proj if self.has_gate else self.up_proj + ws_up = self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv + act_fp8, act_scales = per_token_cast_to_fp8(selected_hidden_states_g, use_ue8m0=False) + act_fp8 = _pad_for_deepgemm(act_fp8, sorted_to_padded, total_padded_rows) + act_scales = _pad_for_deepgemm(act_scales, sorted_to_padded, total_padded_rows) + proj_out = torch.zeros(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16) + m_grouped_fp8_gemm_nt_contiguous( + (act_fp8, act_scales), (w_up, ws_up.float()), proj_out, grouped_layout, use_psum_layout=use_psum_layout + ) + + # Apply gating or activation + if self.has_gate: + proj_out = self._apply_gate(proj_out) + else: + proj_out = self.act_fn(proj_out) + + # --- Down projection per expert (deep-gemm grouped contiguous) --- + proj_fp8, proj_scales = per_token_cast_to_fp8(proj_out, use_ue8m0=False) + proj_out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) + m_grouped_fp8_gemm_nt_contiguous( + (proj_fp8, proj_scales), + (self.down_proj, self.down_proj_scale_inv.float()), + proj_out, + grouped_layout, + use_psum_layout=use_psum_layout, + ) + + # Remove padding rows + proj_out = _unpad_from_deepgemm_contiguous_layout(proj_out, sorted_to_padded) + + # Apply routing weights + weighted_out = proj_out * sample_weights_g.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim) + + # Restore original order + weighted_out = weighted_out[inv_perm] + + # Accumulate results using deterministic reshape+sum instead of index_add_ + # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd) + final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1) + + return final_hidden_states.to(hidden_states.dtype) + + +def bf16_deepgemm_experts_forward( + self: torch.nn.Module, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, +) -> torch.Tensor: + if self.is_transposed: + raise ValueError("deepgemm bf16 path requires non-transposed weights (is_transposed=False)") + if not self.has_gate: + raise ValueError("deepgemm bf16 path requires gated experts (has_gate=True)") + if self.has_bias: + raise ValueError("deepgemm bf16 path 
does not support bias (m_grouped_bf16_gemm_nt_contiguous has no bias input)") + if hidden_states.device.type != "cuda": + raise ValueError("deepgemm bf16 path requires CUDA device") + + _, _, m_grouped_bf16_gemm_nt_contiguous, _ = _load_deepgemm_kernel() + + device = hidden_states.device + num_top_k = top_k_index.size(-1) + num_tokens = hidden_states.size(0) + hidden_dim = hidden_states.size(-1) + + # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) + token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,) + sample_weights = top_k_weights.reshape(-1) # (S,) + expert_ids = top_k_index.reshape(-1) # (S,) + + # Handle invalid expert IDs from Expert Parallelism (EP) + invalid_mask = expert_ids >= self.num_experts + expert_ids = expert_ids.clamp(0, self.num_experts - 1) + + # Sort by expert for grouped processing + perm = torch.argsort(expert_ids) + inv_perm = torch.empty_like(perm) + inv_perm[perm] = torch.arange(perm.size(0), device=device) + + expert_ids_g = expert_ids[perm] + sample_weights_g = sample_weights[perm] + invalid_mask_g = invalid_mask[perm] + selected_hidden_states_g = hidden_states[token_idx[perm]] + + sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( + expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT + ) + use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 + + # --- Up projection per expert (deep-gemm grouped contiguous, bf16) --- + act = _pad_for_deepgemm(selected_hidden_states_g, sorted_to_padded, total_padded_rows) + proj_out = torch.zeros( + total_padded_rows, self.gate_up_proj.shape[1], device=device, dtype=hidden_states.dtype + ) + m_grouped_bf16_gemm_nt_contiguous( + act, self.gate_up_proj, proj_out, grouped_layout, use_psum_layout=use_psum_layout + ) + + # Apply gating + proj_out = self._apply_gate(proj_out) + + # --- Down projection per expert (deep-gemm grouped contiguous, bf16) --- + out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=hidden_states.dtype) + m_grouped_bf16_gemm_nt_contiguous( + proj_out, self.down_proj, out, grouped_layout, use_psum_layout=use_psum_layout + ) + + # Remove padding rows + out = _unpad_from_deepgemm_contiguous_layout(out, sorted_to_padded) + + # Apply routing weights and zero out invalid expert contributions + weighted_out = out * sample_weights_g.to(out.dtype).unsqueeze(-1) # (S, hidden_dim) + weighted_out.masked_fill_(invalid_mask_g.unsqueeze(-1), 0.0) + + # Restore original order + weighted_out = weighted_out[inv_perm] + + # Accumulate results using deterministic reshape+sum instead of index_add_ + # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd) + final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1) + + return final_hidden_states.to(hidden_states.dtype) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index a6b9a517b20d..5f583533792e 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -19,7 +19,7 @@ from ..core_model_loading import ConversionOps, _IdentityOp from ..quantizers.quantizers_utils import should_convert_module from ..utils import logging -from ..utils.import_utils import get_cuda_runtime_version, resolve_internal_import +from .deepgemm import fp8_deepgemm_experts_forward, fp8_deepgemm_matmul from .hub_kernels import lazy_load_kernel from .moe import 
ExpertsInterface, use_experts_implementation @@ -31,11 +31,6 @@ _FP8_MIN = torch.finfo(_FP8_DTYPE).min _FP8_MAX = torch.finfo(_FP8_DTYPE).max -# DeepGEMM requires M-dimension alignment to 128 for TMA-based contiguous grouped GEMM -# TMA is an H100 hardware addition that allows applications to asynchronously and -# bi-directionally transfer 1D-5D tensors between GPU global and shared memory -_DEEPGEMM_M_ALIGNMENT = 128 - # Lazily-loaded finegrained-fp8 Triton kernel functions (populated by _load_triton_kernel) triton_fp8_matmul = None triton_fp8_act_quant = None @@ -44,13 +39,6 @@ # _triton_available: None = not yet attempted, True = loaded, False = failed (won't retry) _triton_available = None -# Lazily-loaded DeepGEMM kernel functions (populated by _load_deepgemm_kernel) -deepgemm_fp8_matmul = None -deepgemm_grouped_fp8_matmul = None -deepgemm_per_token_cast_to_fp8 = None -# _deepgemm_available: None = not yet attempted, True = loaded, False = failed (won't retry) -_deepgemm_available = None - def _load_triton_kernel(): """Lazily load the finegrained-fp8 Triton kernel and extract functions. @@ -97,67 +85,6 @@ def _load_triton_kernel(): _triton_available = True -def _load_deepgemm_kernel(): - """Lazily load the DeepGEMM kernel and extract functions with proper names. - - Uses the hub kernels lazy loading pattern. Raises an error if the kernel - cannot be loaded, required functions are missing, or the hardware is insufficient. - Only attempts loading once. - """ - global _deepgemm_available, deepgemm_fp8_matmul, deepgemm_grouped_fp8_matmul, deepgemm_per_token_cast_to_fp8 - - if _deepgemm_available is not None: - if not _deepgemm_available: - raise ImportError("DeepGEMM kernel is not available (previous load attempt failed).") - return - - _deepgemm_available = False # mark attempted before any early exit - - # DeepGEMM requires CUDA and a compatible GPU - if not torch.cuda.is_available(): - raise ImportError( - "DeepGEMM kernel requires CUDA, but CUDA is not available. Use a different `experts_implementation`." - ) - - # DeepGEMM requires Hopper (SM90) or newer for FP8 WGMMA instructions - major = torch.cuda.get_device_capability()[0] - if major < 9: - raise ImportError( - f"DeepGEMM requires a Hopper (SM90+) or newer GPU, but the current device " - f"has compute capability {major}.x. Use a different `experts_implementation`." - ) - - # DeepGEMM requires CUDA runtime ≥ 12.3. - cuda_major, cuda_minor = get_cuda_runtime_version() - if cuda_major < 12 or (cuda_major == 12 and cuda_minor < 3): - raise ImportError( - f"DeepGEMM requires CUDA runtime 12.3+, but found {cuda_major}.{cuda_minor}. " - "Please upgrade your CUDA toolkit or use a different `experts_implementation`." - ) - - kernel = lazy_load_kernel("deep-gemm") - deepgemm_fp8_matmul = getattr(kernel, "fp8_gemm_nt") - deepgemm_grouped_fp8_matmul = getattr(kernel, "m_grouped_fp8_gemm_nt_contiguous") - deepgemm_per_token_cast_to_fp8 = resolve_internal_import(kernel, chained_path="utils.per_token_cast_to_fp8") - - missing = [ - name - for name, attr in [ - ("fp8_gemm_nt", deepgemm_fp8_matmul), - ("m_grouped_fp8_gemm_nt_contiguous", deepgemm_grouped_fp8_matmul), - ("utils.per_token_cast_to_fp8", deepgemm_per_token_cast_to_fp8), - ] - if attr is None - ] - if missing: - raise ImportError( - f"DeepGEMM kernel is missing required functions: {', '.join(missing)}. " - "Please update the `kernels` package (`pip install -U kernels`)." 
- ) - - _deepgemm_available = True - - def _cdiv(a: int, b: int) -> int: """Ceiling division.""" return (a + b - 1) // b @@ -191,21 +118,14 @@ def w8a8_fp8_matmul( """ if block_size is not None and block_size[0] == block_size[1] == 128: try: - _load_deepgemm_kernel() - global deepgemm_fp8_matmul + # 3-6x faster than Triton + return fp8_deepgemm_matmul(A, B, As, Bs, output_dtype=output_dtype) except ImportError: logger.warning_once( "DeepGEMM kernel is not available or compatible, falling back to Triton finegrained-fp8 kernel. " "To use DeepGEMM FP8 matmul, ensure you have a Hopper (SM90+) or newer GPU with CUDA runtime 12.3+, " "and that the `kernels` package is installed and up to date (`pip install -U kernels`)." ) - else: - # 3-6x faster than Triton - A_2d = A.view(-1, A.shape[-1]) - As_2d = As.view(-1, As.shape[-1]) - output = torch.empty(A_2d.shape[0], B.shape[0], device=A.device, dtype=output_dtype) - deepgemm_fp8_matmul((A_2d, As_2d.float()), (B, Bs.float()), output) - return output.view(A.shape[:-1] + (B.shape[0],)) _load_triton_kernel() global triton_fp8_matmul @@ -434,150 +354,6 @@ def fp8_grouped_mm_experts_forward( return final_hidden_states.to(hidden_states.dtype) -def _build_deepgemm_contiguous_layout(expert_ids_sorted: torch.Tensor, num_experts: int, alignment: int) -> tuple: - """Build a TMA-aligned contiguous layout for DeepGEMM grouped GEMM. - - DeepGEMM requires M-dimension alignment per expert for TMA. This computes - the mapping from sorted token positions to padded row positions, and the - layout tensor that DeepGEMM uses to identify expert boundaries. - - Returns: - sorted_to_padded: (num_tokens,) index map from sorted position to padded row - grouped_layout: expert layout tensor (format depends on GPU architecture) - total_padded_rows: total number of rows including alignment padding - """ - device = expert_ids_sorted.device - num_tokens = expert_ids_sorted.size(0) - tokens_per_expert = torch.histc(expert_ids_sorted.int(), bins=num_experts, min=0, max=num_experts - 1).long() - aligned_tokens_per_expert = ((tokens_per_expert + alignment - 1) // alignment) * alignment - # Upper bound avoids GPU→CPU sync; padding rows are skipped by DeepGEMM. 
- total_padded_rows = num_tokens + min(num_tokens, num_experts) * (alignment - 1) - - padding_per_expert = aligned_tokens_per_expert - tokens_per_expert - cumulative_padding = padding_per_expert.cumsum(0) - padding_per_expert - sorted_to_padded = torch.arange(num_tokens, device=device) + cumulative_padding[expert_ids_sorted] - - if torch.cuda.get_device_capability(device)[0] >= 10: # Blackwell (SM100+) - grouped_layout = tokens_per_expert.cumsum(0).int() - else: - # Hopper: per-row expert id, -1 for padding rows - grouped_layout = torch.full((total_padded_rows,), -1, device=device, dtype=torch.int32) - grouped_layout[sorted_to_padded] = expert_ids_sorted.int() - - return sorted_to_padded, grouped_layout, total_padded_rows - - -def _pad_to_deepgemm_contiguous_layout( - hidden_states: torch.Tensor, - scales: torch.Tensor, - sorted_to_padded: torch.Tensor, - total_padded_rows: int, -) -> tuple[torch.Tensor, torch.Tensor]: - """Pad sorted hidden states and scales into the TMA-aligned contiguous layout.""" - hidden_padded = torch.zeros( - total_padded_rows, hidden_states.shape[1], device=hidden_states.device, dtype=hidden_states.dtype - ) - hidden_padded[sorted_to_padded] = hidden_states - scales_padded = torch.zeros(total_padded_rows, scales.shape[1], device=hidden_states.device, dtype=torch.float32) - scales_padded[sorted_to_padded] = scales - return hidden_padded, scales_padded - - -def _unpad_from_deepgemm_contiguous_layout( - hidden_states_padded: torch.Tensor, sorted_to_padded: torch.Tensor -) -> torch.Tensor: - """Remove padding rows from the TMA-aligned contiguous layout.""" - return hidden_states_padded[sorted_to_padded] - - -def fp8_deepgemm_experts_forward( - self: torch.nn.Module, - hidden_states: torch.Tensor, - top_k_index: torch.Tensor, - top_k_weights: torch.Tensor, -) -> torch.Tensor: - if self.activation_scheme == "static": - raise NotImplementedError( - "deepgemm experts dispatch does not support activation_scheme='static'. " - "Use the default eager dispatch or switch to activation_scheme='dynamic'." - ) - if self.block_size is None: - raise ValueError( - "DeepGEMM requires block-wise quantization (block_size=[128, 128]), " - "but got per-tensor quantization (block_size=None)." 
- ) - if self.block_size[0] != 128 or self.block_size[1] != 128: - raise ValueError(f"DeepGEMM requires block_size=(128, 128), got {self.block_size}") - - _load_deepgemm_kernel() - global deepgemm_grouped_fp8_matmul, deepgemm_per_token_cast_to_fp8 - - device = hidden_states.device - num_top_k = top_k_index.size(-1) - num_tokens = hidden_states.size(0) - hidden_dim = hidden_states.size(-1) - - # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) - token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,) - sample_weights = top_k_weights.reshape(-1) # (S,) - expert_ids = top_k_index.reshape(-1) # (S,) - - # Sort by expert for grouped processing - perm = torch.argsort(expert_ids) - inv_perm = torch.empty_like(perm) - inv_perm[perm] = torch.arange(perm.size(0), device=device) - - expert_ids_g = expert_ids[perm] - sample_weights_g = sample_weights[perm] - selected_hidden_states_g = hidden_states[token_idx[perm]] - - # Build TMA-aligned contiguous layout for DeepGEMM - sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( - expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT - ) - - # --- Up projection per expert (DeepGEMM grouped contiguous) --- - w_up = self.gate_up_proj if self.has_gate else self.up_proj - ws_up = self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv - act_fp8, act_scales = deepgemm_per_token_cast_to_fp8(selected_hidden_states_g, use_ue8m0=False) - act_fp8, act_scales = _pad_to_deepgemm_contiguous_layout(act_fp8, act_scales, sorted_to_padded, total_padded_rows) - proj_out = torch.zeros(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16) - use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 - deepgemm_grouped_fp8_matmul( - (act_fp8, act_scales), (w_up, ws_up.float()), proj_out, grouped_layout, use_psum_layout=use_psum_layout - ) - - # Apply gating or activation - if self.has_gate: - proj_out = self._apply_gate(proj_out) - else: - proj_out = self.act_fn(proj_out) - - # --- Down projection per expert (DeepGEMM grouped contiguous) --- - w_down = self.down_proj - ws_down = self.down_proj_scale_inv - proj_fp8, proj_scales = deepgemm_per_token_cast_to_fp8(proj_out, use_ue8m0=False) - proj_out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) - deepgemm_grouped_fp8_matmul( - (proj_fp8, proj_scales), (w_down, ws_down.float()), proj_out, grouped_layout, use_psum_layout=use_psum_layout - ) - - # Remove padding rows - proj_out = _unpad_from_deepgemm_contiguous_layout(proj_out, sorted_to_padded) - - # Apply routing weights - weighted_out = proj_out * sample_weights_g.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim) - - # Restore original order - weighted_out = weighted_out[inv_perm] - - # Accumulate results using deterministic reshape+sum instead of index_add_ - # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd) - final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1) - - return final_hidden_states.to(hidden_states.dtype) - - class FP8Experts(nn.Module): def __init__( self, diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index d17522d26daa..622b0ceb2fa6 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -23,6 +23,7 @@ is_torch_less_or_equal, is_torchdynamo_compiling, ) +from .deepgemm import bf16_deepgemm_experts_forward if 
is_torch_available(): @@ -460,6 +461,7 @@ class ExpertsInterface(GeneralInterface): _global_mapping = { "batched_mm": batched_mm_experts_forward, "grouped_mm": grouped_mm_experts_forward, + "deepgemm": bf16_deepgemm_experts_forward, } def get_interface(self, experts_implementation: str, default: Callable) -> Callable: From 357a0355c9f6f6a9df20c85a163d5711b5635a76 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 23 Apr 2026 15:34:29 +0200 Subject: [PATCH 02/22] style --- src/transformers/integrations/deepgemm.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index 7a8fb0786446..f2951deda99c 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -273,7 +273,9 @@ def bf16_deepgemm_experts_forward( if not self.has_gate: raise ValueError("deepgemm bf16 path requires gated experts (has_gate=True)") if self.has_bias: - raise ValueError("deepgemm bf16 path does not support bias (m_grouped_bf16_gemm_nt_contiguous has no bias input)") + raise ValueError( + "deepgemm bf16 path does not support bias (m_grouped_bf16_gemm_nt_contiguous has no bias input)" + ) if hidden_states.device.type != "cuda": raise ValueError("deepgemm bf16 path requires CUDA device") @@ -310,9 +312,7 @@ def bf16_deepgemm_experts_forward( # --- Up projection per expert (deep-gemm grouped contiguous, bf16) --- act = _pad_for_deepgemm(selected_hidden_states_g, sorted_to_padded, total_padded_rows) - proj_out = torch.zeros( - total_padded_rows, self.gate_up_proj.shape[1], device=device, dtype=hidden_states.dtype - ) + proj_out = torch.zeros(total_padded_rows, self.gate_up_proj.shape[1], device=device, dtype=hidden_states.dtype) m_grouped_bf16_gemm_nt_contiguous( act, self.gate_up_proj, proj_out, grouped_layout, use_psum_layout=use_psum_layout ) @@ -322,9 +322,7 @@ def bf16_deepgemm_experts_forward( # --- Down projection per expert (deep-gemm grouped contiguous, bf16) --- out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=hidden_states.dtype) - m_grouped_bf16_gemm_nt_contiguous( - proj_out, self.down_proj, out, grouped_layout, use_psum_layout=use_psum_layout - ) + m_grouped_bf16_gemm_nt_contiguous(proj_out, self.down_proj, out, grouped_layout, use_psum_layout=use_psum_layout) # Remove padding rows out = _unpad_from_deepgemm_contiguous_layout(out, sorted_to_padded) From 741b5eb717829ba7cfba22bc823fc48e73381b40 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Thu, 23 Apr 2026 15:44:38 +0200 Subject: [PATCH 03/22] full support --- src/transformers/integrations/deepgemm.py | 71 +++++++++++++++-------- src/transformers/integrations/sonicmoe.py | 19 +++++- 2 files changed, 64 insertions(+), 26 deletions(-) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index f2951deda99c..98c2b83032e2 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -27,7 +27,7 @@ import torch from ..utils import logging -from ..utils.import_utils import get_cuda_runtime_version, resolve_internal_import +from ..utils.import_utils import get_cuda_runtime_version, is_kernels_available, resolve_internal_import from .hub_kernels import lazy_load_kernel @@ -50,8 +50,12 @@ def _load_deepgemm_kernel(): Returns: Tuple of (fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous, - m_grouped_bf16_gemm_nt_contiguous, per_token_cast_to_fp8) from the deep-gemm kernel. 
+ m_grouped_bf16_gemm_nt_contiguous, m_grouped_bf16_gemm_nn_contiguous, + per_token_cast_to_fp8) from the deep-gemm kernel. """ + if not is_kernels_available(): + raise ImportError("deep-gemm kernel requires the `kernels` package. Install it with `pip install -U kernels`.") + if not torch.cuda.is_available(): raise ImportError( "deep-gemm kernel requires CUDA, but CUDA is not available. Use a different `experts_implementation`." @@ -82,6 +86,7 @@ def _load_deepgemm_kernel(): fp8_gemm_nt = getattr(kernel, "fp8_gemm_nt", None) m_grouped_fp8_gemm_nt_contiguous = getattr(kernel, "m_grouped_fp8_gemm_nt_contiguous", None) m_grouped_bf16_gemm_nt_contiguous = getattr(kernel, "m_grouped_bf16_gemm_nt_contiguous", None) + m_grouped_bf16_gemm_nn_contiguous = getattr(kernel, "m_grouped_bf16_gemm_nn_contiguous", None) per_token_cast_to_fp8 = resolve_internal_import(kernel, chained_path="utils.per_token_cast_to_fp8") missing = [ @@ -90,6 +95,7 @@ def _load_deepgemm_kernel(): ("fp8_gemm_nt", fp8_gemm_nt), ("m_grouped_fp8_gemm_nt_contiguous", m_grouped_fp8_gemm_nt_contiguous), ("m_grouped_bf16_gemm_nt_contiguous", m_grouped_bf16_gemm_nt_contiguous), + ("m_grouped_bf16_gemm_nn_contiguous", m_grouped_bf16_gemm_nn_contiguous), ("utils.per_token_cast_to_fp8", per_token_cast_to_fp8), ] if attr is None @@ -100,7 +106,13 @@ def _load_deepgemm_kernel(): "Please update the `kernels` package (`pip install -U kernels`)." ) - return fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous, m_grouped_bf16_gemm_nt_contiguous, per_token_cast_to_fp8 + return ( + fp8_gemm_nt, + m_grouped_fp8_gemm_nt_contiguous, + m_grouped_bf16_gemm_nt_contiguous, + m_grouped_bf16_gemm_nn_contiguous, + per_token_cast_to_fp8, + ) def fp8_deepgemm_matmul( @@ -120,7 +132,7 @@ def fp8_deepgemm_matmul( Bs: (N//128, K//128) float32 — per-block weight scales output_dtype: desired output dtype. 
""" - fp8_gemm_nt, _, _, _ = _load_deepgemm_kernel() + fp8_gemm_nt, _, _, _, _ = _load_deepgemm_kernel() A_2d = A.view(-1, A.shape[-1]) As_2d = As.view(-1, As.shape[-1]) output = torch.empty(A_2d.shape[0], B.shape[0], device=A.device, dtype=output_dtype) @@ -192,7 +204,7 @@ def fp8_deepgemm_experts_forward( if self.block_size[0] != 128 or self.block_size[1] != 128: raise ValueError(f"deep-gemm requires block_size=(128, 128), got {self.block_size}") - _, m_grouped_fp8_gemm_nt_contiguous, _, per_token_cast_to_fp8 = _load_deepgemm_kernel() + _, m_grouped_fp8_gemm_nt_contiguous, _, _, per_token_cast_to_fp8 = _load_deepgemm_kernel() device = hidden_states.device num_top_k = top_k_index.size(-1) @@ -268,18 +280,15 @@ def bf16_deepgemm_experts_forward( top_k_index: torch.Tensor, top_k_weights: torch.Tensor, ) -> torch.Tensor: - if self.is_transposed: - raise ValueError("deepgemm bf16 path requires non-transposed weights (is_transposed=False)") - if not self.has_gate: - raise ValueError("deepgemm bf16 path requires gated experts (has_gate=True)") - if self.has_bias: - raise ValueError( - "deepgemm bf16 path does not support bias (m_grouped_bf16_gemm_nt_contiguous has no bias input)" - ) - if hidden_states.device.type != "cuda": - raise ValueError("deepgemm bf16 path requires CUDA device") - - _, _, m_grouped_bf16_gemm_nt_contiguous, _ = _load_deepgemm_kernel() + if hidden_states.dtype != torch.bfloat16: + raise ValueError(f"deepgemm bf16 path requires bfloat16 hidden states, got {hidden_states.dtype}") + + _, _, m_grouped_bf16_gemm_nt_contiguous, m_grouped_bf16_gemm_nn_contiguous, _ = _load_deepgemm_kernel() + # Non-transposed HF experts have weight layout (E, N, K) -> NT kernel. + # Transposed HF experts have weight layout (E, K, N) -> NN kernel. + m_grouped_bf16_gemm = ( + m_grouped_bf16_gemm_nn_contiguous if self.is_transposed else m_grouped_bf16_gemm_nt_contiguous + ) device = hidden_states.device num_top_k = top_k_index.size(-1) @@ -301,8 +310,8 @@ def bf16_deepgemm_experts_forward( inv_perm[perm] = torch.arange(perm.size(0), device=device) expert_ids_g = expert_ids[perm] - sample_weights_g = sample_weights[perm] invalid_mask_g = invalid_mask[perm] + sample_weights_g = sample_weights[perm] selected_hidden_states_g = hidden_states[token_idx[perm]] sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( @@ -311,18 +320,30 @@ def bf16_deepgemm_experts_forward( use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 # --- Up projection per expert (deep-gemm grouped contiguous, bf16) --- + w_up = self.gate_up_proj if self.has_gate else self.up_proj + # Output dim is the last weight axis when transposed (E, K, N), second axis when not (E, N, K). + up_out_dim = w_up.shape[-1] if self.is_transposed else w_up.shape[1] act = _pad_for_deepgemm(selected_hidden_states_g, sorted_to_padded, total_padded_rows) - proj_out = torch.zeros(total_padded_rows, self.gate_up_proj.shape[1], device=device, dtype=hidden_states.dtype) - m_grouped_bf16_gemm_nt_contiguous( - act, self.gate_up_proj, proj_out, grouped_layout, use_psum_layout=use_psum_layout - ) + proj_out = torch.zeros(total_padded_rows, up_out_dim, device=device, dtype=hidden_states.dtype) + m_grouped_bf16_gemm(act, w_up, proj_out, grouped_layout, use_psum_layout=use_psum_layout) + + # The kernel has no bias input -> add per-expert bias post-GEMM; padding rows get discarded at unpad time. 
+ if self.has_bias: + up_bias = self.gate_up_proj_bias if self.has_gate else self.up_proj_bias + proj_out = proj_out + _pad_for_deepgemm(up_bias[expert_ids_g], sorted_to_padded, total_padded_rows) - # Apply gating - proj_out = self._apply_gate(proj_out) + # Apply gating or activation + if self.has_gate: + proj_out = self._apply_gate(proj_out) + else: + proj_out = self.act_fn(proj_out) # --- Down projection per expert (deep-gemm grouped contiguous, bf16) --- out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=hidden_states.dtype) - m_grouped_bf16_gemm_nt_contiguous(proj_out, self.down_proj, out, grouped_layout, use_psum_layout=use_psum_layout) + m_grouped_bf16_gemm(proj_out, self.down_proj, out, grouped_layout, use_psum_layout=use_psum_layout) + + if self.has_bias: + out = out + _pad_for_deepgemm(self.down_proj_bias[expert_ids_g], sorted_to_padded, total_padded_rows) # Remove padding rows out = _unpad_from_deepgemm_contiguous_layout(out, sorted_to_padded) diff --git a/src/transformers/integrations/sonicmoe.py b/src/transformers/integrations/sonicmoe.py index e322bb4bc061..df6bfbbd8f1a 100644 --- a/src/transformers/integrations/sonicmoe.py +++ b/src/transformers/integrations/sonicmoe.py @@ -23,6 +23,7 @@ import torch from ..utils import logging +from ..utils.import_utils import is_kernels_available from .hub_kernels import lazy_load_kernel @@ -38,11 +39,27 @@ def _load_sonic_kernel(): Load sonic-moe once and return its required symbols. Raises: - ImportError if the kernel or required symbols are not found. + ImportError if CUDA/hardware requirements are not met, or if the kernel or + required symbols are not found. Returns: Tuple of (ActivationType, moe_general_routing_inputs function) from the sonic-moe kernel. """ + if not is_kernels_available(): + raise ImportError("sonic-moe kernel requires the `kernels` package. Install it with `pip install -U kernels`.") + + if not torch.cuda.is_available(): + raise ImportError( + "sonic-moe kernel requires CUDA, but CUDA is not available. Use a different `experts_implementation`." + ) + + # sonic-moe requires Hopper (SM90) or newer + major = torch.cuda.get_device_capability()[0] + if major < 9: + raise ImportError( + f"sonic-moe requires a Hopper (SM90+) or newer GPU, but the current device " + f"has compute capability {major}.x. Use a different `experts_implementation`." + ) kernel = lazy_load_kernel("sonic-moe") if kernel is None: From 9fc3662d1f9ec22ff94615513b1d5c189772a5ef Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 09:45:11 +0200 Subject: [PATCH 04/22] support EP better using offsets ! --- src/transformers/integrations/deepgemm.py | 118 ++++++++++-------- .../integrations/finegrained_fp8.py | 104 ++++++++------- src/transformers/integrations/moe.py | 51 +++----- src/transformers/integrations/sonicmoe.py | 5 +- .../integrations/tensor_parallel.py | 14 +++ 5 files changed, 154 insertions(+), 138 deletions(-) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index 98c2b83032e2..4d5fcf8095b4 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -17,11 +17,13 @@ Provides: - `fp8_deepgemm_matmul`: FP8 dense matmul used as a fast path inside the finegrained-fp8 Linear. - `fp8_deepgemm_experts_forward`: FP8 M-grouped experts forward, registered as "deepgemm" in the FP8 ExpertsInterface. -- `bf16_deepgemm_experts_forward`: BF16 M-grouped experts forward, registered as "deepgemm" in the ExpertsInterface. 
+- `deepgemm_experts_forward`: BF16 M-grouped experts forward, registered as "deepgemm" in the ExpertsInterface. Requirements: CUDA, Hopper (SM90+), CUDA runtime >= 12.3, `kernels`. """ +from __future__ import annotations + import functools import torch @@ -80,7 +82,8 @@ def _load_deepgemm_kernel(): kernel = lazy_load_kernel("deep-gemm") if kernel is None: raise ImportError( - "deep-gemm kernel not found. Make sure you have the `kernels` package installed (`pip install -U kernels`)." + "Failed to load the deep-gemm kernel — check that `kernels-community/deep-gemm` " + "has a build matching the current torch/CUDA." ) fp8_gemm_nt = getattr(kernel, "fp8_gemm_nt", None) @@ -140,42 +143,56 @@ def fp8_deepgemm_matmul( return output.view(A.shape[:-1] + (B.shape[0],)) -def _build_deepgemm_contiguous_layout(expert_ids_sorted: torch.Tensor, num_experts: int, alignment: int) -> tuple: - """Build a TMA-aligned contiguous layout for deep-gemm's grouped GEMM. +def _build_deepgemm_contiguous_layout( + expert_ids_sorted: torch.Tensor, num_experts: int, alignment: int, use_psum_layout: bool +) -> tuple: + """Build the TMA-aligned layout deep-gemm's grouped GEMM expects. - deep-gemm requires M-dimension alignment per expert for TMA. This computes - the mapping from sorted token positions to padded row positions, and the - layout tensor that deep-gemm uses to identify expert boundaries. + Returns `(sorted_to_padded, grouped_layout, total_padded_rows)`. `grouped_layout` encodes + expert boundaries as a cumsum of aligned counts on Blackwell (`use_psum_layout=True`) or + per-row expert ids with -1 for padding on Hopper. - Returns: - sorted_to_padded: (num_tokens,) index map from sorted position to padded row - grouped_layout: expert layout tensor (format depends on GPU architecture) - total_padded_rows: total number of rows including alignment padding + Accepts EP sentinels: values in `expert_ids_sorted` equal to `num_experts` (unclamped sentinels) + are routed past the last aligned expert block and marked `-1` in the Hopper layout (and + excluded from the Blackwell cumsum), so deep-gemm skips them. """ device = expert_ids_sorted.device num_tokens = expert_ids_sorted.size(0) + # histc drops values > max, so EP sentinels (== num_experts) are excluded from the per-expert count. tokens_per_expert = torch.histc(expert_ids_sorted.int(), bins=num_experts, min=0, max=num_experts - 1).long() aligned_tokens_per_expert = ((tokens_per_expert + alignment - 1) // alignment) * alignment # Upper bound avoids GPU->CPU sync; padding rows are skipped by deep-gemm. total_padded_rows = num_tokens + min(num_tokens, num_experts) * (alignment - 1) + # Zero-prepended inclusive cumsum of per-expert padding. Indices [0, num_experts) give the + # exclusive cumsum (padding before expert i) and index `num_experts` gives `sum(padding)`, + # which routes EP sentinels past all valid aligned expert blocks on Blackwell (where the + # kernel stops at `aligned_cumsum[-1]`) — so sentinels don't go through the GEMM. 
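+    # Illustration (hypothetical numbers): tokens_per_expert=[3, 130] plus one EP sentinel, alignment=128
+    # -> padding_per_expert=[125, 126] and cumulative_padding=[0, 125, 251]; the sentinel (sorted last,
+    # position 133) lands at padded row 133 + 251 = 384 = aligned cumsum[-1], i.e. past both expert blocks.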
padding_per_expert = aligned_tokens_per_expert - tokens_per_expert - cumulative_padding = padding_per_expert.cumsum(0) - padding_per_expert + cumulative_padding = torch.nn.functional.pad(padding_per_expert.cumsum(0), (1, 0)) sorted_to_padded = torch.arange(num_tokens, device=device) + cumulative_padding[expert_ids_sorted] - if torch.cuda.get_device_capability(device)[0] >= 10: # Blackwell (SM100+) - grouped_layout = tokens_per_expert.cumsum(0).int() + if use_psum_layout: # Blackwell (SM100+) + # psum layout: cumsum of *aligned* per-expert counts — sentinels sit at positions >= + # `grouped_layout[-1]` (by construction of `cumulative_padding`), so the scheduler + # stops before them. The kernel's `num_m_blocks = ceil_div(layout[i] - align(layout[i-1], 128), BLOCK_M)` + # between experts only matches the padded tensor when the stored cumsum is over aligned counts. + grouped_layout = aligned_tokens_per_expert.cumsum(0).int() else: - # Hopper: per-row expert id, -1 for padding rows + # Hopper: per-row expert id, -1 for padding rows and for sentinel slots (kernel skips -1). grouped_layout = torch.full((total_padded_rows,), -1, device=device, dtype=torch.int32) - grouped_layout[sorted_to_padded] = expert_ids_sorted.int() + grouped_layout[sorted_to_padded] = torch.where(expert_ids_sorted < num_experts, expert_ids_sorted.int(), -1) return sorted_to_padded, grouped_layout, total_padded_rows def _pad_for_deepgemm(x: torch.Tensor, sorted_to_padded: torch.Tensor, total_padded_rows: int) -> torch.Tensor: - """Pad a sorted tensor into the TMA-aligned contiguous layout.""" - padded = torch.zeros(total_padded_rows, *x.shape[1:], device=x.device, dtype=x.dtype) + """Pad a sorted tensor into the TMA-aligned contiguous layout. + + Padding rows are left uninitialized — the kernel skips them via `grouped_layout=-1` (Hopper) + or via the psum offsets (Blackwell), so their values never enter the computation. 
+ """ + padded = torch.empty(total_padded_rows, *x.shape[1:], device=x.device, dtype=x.dtype) padded[sorted_to_padded] = x return padded @@ -212,23 +229,18 @@ def fp8_deepgemm_experts_forward( hidden_dim = hidden_states.size(-1) # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) - token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,) sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) # Sort by expert for grouped processing - perm = torch.argsort(expert_ids) - inv_perm = torch.empty_like(perm) - inv_perm[perm] = torch.arange(perm.size(0), device=device) - - expert_ids_g = expert_ids[perm] + expert_ids_g, perm = torch.sort(expert_ids) + selected_hidden_states_g = hidden_states[perm // num_top_k] sample_weights_g = sample_weights[perm] - selected_hidden_states_g = hidden_states[token_idx[perm]] + use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( - expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT + expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT, use_psum_layout=use_psum_layout ) - use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 # --- Up projection per expert (deep-gemm grouped contiguous) --- w_up = self.gate_up_proj if self.has_gate else self.up_proj @@ -236,7 +248,7 @@ def fp8_deepgemm_experts_forward( act_fp8, act_scales = per_token_cast_to_fp8(selected_hidden_states_g, use_ue8m0=False) act_fp8 = _pad_for_deepgemm(act_fp8, sorted_to_padded, total_padded_rows) act_scales = _pad_for_deepgemm(act_scales, sorted_to_padded, total_padded_rows) - proj_out = torch.zeros(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16) + proj_out = torch.empty(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16) m_grouped_fp8_gemm_nt_contiguous( (act_fp8, act_scales), (w_up, ws_up.float()), proj_out, grouped_layout, use_psum_layout=use_psum_layout ) @@ -249,7 +261,7 @@ def fp8_deepgemm_experts_forward( # --- Down projection per expert (deep-gemm grouped contiguous) --- proj_fp8, proj_scales = per_token_cast_to_fp8(proj_out, use_ue8m0=False) - proj_out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) + proj_out = torch.empty(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) m_grouped_fp8_gemm_nt_contiguous( (proj_fp8, proj_scales), (self.down_proj, self.down_proj_scale_inv.float()), @@ -262,9 +274,11 @@ def fp8_deepgemm_experts_forward( proj_out = _unpad_from_deepgemm_contiguous_layout(proj_out, sorted_to_padded) # Apply routing weights - weighted_out = proj_out * sample_weights_g.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim) + weighted_out = proj_out * sample_weights_g.unsqueeze(-1) # (S, hidden_dim) # Restore original order + inv_perm = torch.empty_like(perm) + inv_perm[perm] = torch.arange(perm.size(0), device=device) weighted_out = weighted_out[inv_perm] # Accumulate results using deterministic reshape+sum instead of index_add_ @@ -274,14 +288,14 @@ def fp8_deepgemm_experts_forward( return final_hidden_states.to(hidden_states.dtype) -def bf16_deepgemm_experts_forward( +def deepgemm_experts_forward( self: torch.nn.Module, hidden_states: torch.Tensor, top_k_index: torch.Tensor, top_k_weights: torch.Tensor, ) -> torch.Tensor: if hidden_states.dtype != torch.bfloat16: - raise ValueError(f"deepgemm bf16 path requires bfloat16 hidden states, got 
{hidden_states.dtype}") + raise ValueError(f"deepgemm path requires bfloat16 hidden states, got {hidden_states.dtype}") _, _, m_grouped_bf16_gemm_nt_contiguous, m_grouped_bf16_gemm_nn_contiguous, _ = _load_deepgemm_kernel() # Non-transposed HF experts have weight layout (E, N, K) -> NT kernel. @@ -296,41 +310,40 @@ def bf16_deepgemm_experts_forward( hidden_dim = hidden_states.size(-1) # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) - token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,) sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) - # Handle invalid expert IDs from Expert Parallelism (EP) - invalid_mask = expert_ids >= self.num_experts - expert_ids = expert_ids.clamp(0, self.num_experts - 1) - + # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail + # and `_build_deepgemm_contiguous_layout` marks their positions as skipped (-1 on Hopper, beyond + # the cumsum on Blackwell) — so deep-gemm performs no real GEMM work for them. # Sort by expert for grouped processing - perm = torch.argsort(expert_ids) - inv_perm = torch.empty_like(perm) - inv_perm[perm] = torch.arange(perm.size(0), device=device) - - expert_ids_g = expert_ids[perm] - invalid_mask_g = invalid_mask[perm] + expert_ids_g, perm = torch.sort(expert_ids) + selected_hidden_states_g = hidden_states[perm // num_top_k] sample_weights_g = sample_weights[perm] - selected_hidden_states_g = hidden_states[token_idx[perm]] + use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( - expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT + expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT, use_psum_layout=use_psum_layout ) - use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 + + # Clamp now that the layout has been built — needed for the per-row bias gather below to stay + # in-bounds. Bias added to sentinel positions falls in rows the kernel skips, so harmless. + expert_ids_g.clamp_(0, self.num_experts - 1) # --- Up projection per expert (deep-gemm grouped contiguous, bf16) --- w_up = self.gate_up_proj if self.has_gate else self.up_proj # Output dim is the last weight axis when transposed (E, K, N), second axis when not (E, N, K). up_out_dim = w_up.shape[-1] if self.is_transposed else w_up.shape[1] act = _pad_for_deepgemm(selected_hidden_states_g, sorted_to_padded, total_padded_rows) + # `torch.zeros` so sentinel rows read back as 0 at unpad time (kernel leaves them untouched). proj_out = torch.zeros(total_padded_rows, up_out_dim, device=device, dtype=hidden_states.dtype) m_grouped_bf16_gemm(act, w_up, proj_out, grouped_layout, use_psum_layout=use_psum_layout) - # The kernel has no bias input -> add per-expert bias post-GEMM; padding rows get discarded at unpad time. + # The kernel has no bias input -> add per-expert bias in-place on the unpadded slice; + # padding rows get discarded at unpad time. 
if self.has_bias: up_bias = self.gate_up_proj_bias if self.has_gate else self.up_proj_bias - proj_out = proj_out + _pad_for_deepgemm(up_bias[expert_ids_g], sorted_to_padded, total_padded_rows) + proj_out.index_add_(0, sorted_to_padded, up_bias[expert_ids_g]) # Apply gating or activation if self.has_gate: @@ -343,16 +356,17 @@ def bf16_deepgemm_experts_forward( m_grouped_bf16_gemm(proj_out, self.down_proj, out, grouped_layout, use_psum_layout=use_psum_layout) if self.has_bias: - out = out + _pad_for_deepgemm(self.down_proj_bias[expert_ids_g], sorted_to_padded, total_padded_rows) + out.index_add_(0, sorted_to_padded, self.down_proj_bias[expert_ids_g]) # Remove padding rows out = _unpad_from_deepgemm_contiguous_layout(out, sorted_to_padded) - # Apply routing weights and zero out invalid expert contributions - weighted_out = out * sample_weights_g.to(out.dtype).unsqueeze(-1) # (S, hidden_dim) - weighted_out.masked_fill_(invalid_mask_g.unsqueeze(-1), 0.0) + # Apply routing weights + weighted_out = out * sample_weights_g.unsqueeze(-1) # (S, hidden_dim) # Restore original order + inv_perm = torch.empty_like(perm) + inv_perm[perm] = torch.arange(perm.size(0), device=device) weighted_out = weighted_out[inv_perm] # Accumulate results using deterministic reshape+sum instead of index_add_ diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index 5f583533792e..9579d50c5fd7 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -11,6 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +import functools + import torch import torch.nn as nn from torch.nn import functional as F @@ -19,9 +23,11 @@ from ..core_model_loading import ConversionOps, _IdentityOp from ..quantizers.quantizers_utils import should_convert_module from ..utils import logging +from ..utils.import_utils import is_kernels_available from .deepgemm import fp8_deepgemm_experts_forward, fp8_deepgemm_matmul from .hub_kernels import lazy_load_kernel from .moe import ExpertsInterface, use_experts_implementation +from .tensor_parallel import neutralize_ep_sentinels logger = logging.get_logger(__name__) @@ -31,40 +37,36 @@ _FP8_MIN = torch.finfo(_FP8_DTYPE).min _FP8_MAX = torch.finfo(_FP8_DTYPE).max -# Lazily-loaded finegrained-fp8 Triton kernel functions (populated by _load_triton_kernel) -triton_fp8_matmul = None -triton_fp8_act_quant = None -triton_batched_fp8_matmul = None -triton_grouped_fp8_matmul = None -# _triton_available: None = not yet attempted, True = loaded, False = failed (won't retry) -_triton_available = None - +@functools.cache def _load_triton_kernel(): - """Lazily load the finegrained-fp8 Triton kernel and extract functions. - - Uses the hub kernels lazy loading pattern. Raises an error if the kernel - cannot be loaded or required functions are missing. Only attempts loading once. """ - global \ - _triton_available, \ - triton_fp8_act_quant, \ - triton_fp8_matmul, \ - triton_batched_fp8_matmul, \ - triton_grouped_fp8_matmul + Load the finegrained-fp8 Triton kernel once and return its required symbols. 
- if _triton_available is not None: - if not _triton_available: - raise ImportError("finegrained-fp8 kernel is not available (previous load attempt failed).") - return + Raises: + ImportError if the `kernels` package is missing, or the kernel or required + symbols cannot be found. - _triton_available = False # mark attempted before any early exit + Returns: + Tuple of (w8a8_fp8_matmul, fp8_act_quant, w8a8_fp8_matmul_batched, + w8a8_fp8_matmul_grouped) from the finegrained-fp8 kernel. + """ + if not is_kernels_available(): + raise ImportError( + "finegrained-fp8 kernel requires the `kernels` package. Install it with `pip install -U kernels`." + ) kernel = lazy_load_kernel("finegrained-fp8") - triton_fp8_matmul = getattr(kernel, "w8a8_fp8_matmul") - triton_fp8_act_quant = getattr(kernel, "fp8_act_quant") - triton_batched_fp8_matmul = getattr(kernel, "w8a8_fp8_matmul_batched") - triton_grouped_fp8_matmul = getattr(kernel, "w8a8_fp8_matmul_grouped") + if kernel is None: + raise ImportError( + "Failed to load the finegrained-fp8 kernel — check that `kernels-community/finegrained-fp8` " + "has a build matching the current torch/CUDA." + ) + + triton_fp8_matmul = getattr(kernel, "w8a8_fp8_matmul", None) + triton_fp8_act_quant = getattr(kernel, "fp8_act_quant", None) + triton_batched_fp8_matmul = getattr(kernel, "w8a8_fp8_matmul_batched", None) + triton_grouped_fp8_matmul = getattr(kernel, "w8a8_fp8_matmul_grouped", None) missing = [ name @@ -78,11 +80,11 @@ def _load_triton_kernel(): ] if missing: raise ImportError( - f"finegrained-fp8 kernel is missing required functions: {', '.join(missing)}. " + f"finegrained-fp8 kernel is missing required symbols: {', '.join(missing)}. " "Please update the `kernels` package (`pip install -U kernels`)." ) - _triton_available = True + return triton_fp8_matmul, triton_fp8_act_quant, triton_batched_fp8_matmul, triton_grouped_fp8_matmul def _cdiv(a: int, b: int) -> int: @@ -127,8 +129,7 @@ def w8a8_fp8_matmul( "and that the `kernels` package is installed and up to date (`pip install -U kernels`)." ) - _load_triton_kernel() - global triton_fp8_matmul + triton_fp8_matmul, _, _, _ = _load_triton_kernel() return triton_fp8_matmul(A, B, As, Bs, block_size, output_dtype) @@ -182,8 +183,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: scale_inv = self.weight_scale_inv.contiguous() if self.activation_scheme == "dynamic": - _load_triton_kernel() - global triton_fp8_act_quant + _, triton_fp8_act_quant, _, _ = _load_triton_kernel() qinput, scale = triton_fp8_act_quant( input, self.block_size[1] if self.block_size is not None else input.shape[-1] ) @@ -203,7 +203,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: ) if self.bias is not None: - output = output + self.bias + output.add_(self.bias) return output.to(dtype=input.dtype) @@ -220,21 +220,20 @@ def fp8_batched_mm_experts_forward( "Use the default eager dispatch or switch to activation_scheme='dynamic'." ) - _load_triton_kernel() - global triton_batched_fp8_matmul + _, _, triton_batched_fp8_matmul, _ = _load_triton_kernel() - device = hidden_states.device num_top_k = top_k_index.size(-1) num_tokens = hidden_states.size(0) hidden_dim = hidden_states.size(-1) # S is the number of selected tokens-experts pairs (S = num_tokens * num_top_k) - token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,) + # Replicate each token num_top_k times to align with the flattened (S,) routing tensors. 
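+    # E.g. with num_top_k=2, tokens [t0, t1] become [t0, t0, t1, t1], matching the row-major
+    # flattening of top_k_weights / top_k_index into the (S,) tensors below.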
+ selected_hidden_states = hidden_states.repeat_interleave(num_top_k, dim=0) sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) - # Get current hidden states for selected samples - selected_hidden_states = hidden_states[token_idx] + # Handle invalid expert IDs from Expert Parallelism (EP) + neutralize_ep_sentinels(expert_ids, sample_weights, self.num_experts) # --- Up projection per expert (FP8 batched) --- proj_out = triton_batched_fp8_matmul( @@ -263,7 +262,8 @@ def fp8_batched_mm_experts_forward( ) # (S, hidden_dim) # Apply routing weights - weighted_out = proj_out * sample_weights.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim) + # Let torch promote bf16 `proj_out` × fp32 `sample_weights` to fp32 for the reduction below. + weighted_out = proj_out * sample_weights.unsqueeze(-1) # (S, hidden_dim) # Accumulate results using deterministic reshape+sum instead of index_add_ # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd) @@ -284,8 +284,7 @@ def fp8_grouped_mm_experts_forward( "Use the default eager dispatch or switch to activation_scheme='dynamic'." ) - _load_triton_kernel() - global triton_grouped_fp8_matmul + _, _, _, triton_grouped_fp8_matmul = _load_triton_kernel() device = hidden_states.device num_top_k = top_k_index.size(-1) @@ -293,22 +292,18 @@ def fp8_grouped_mm_experts_forward( hidden_dim = hidden_states.size(-1) # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) - token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,) sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) # Sort by expert for grouped processing - perm = torch.argsort(expert_ids) - inv_perm = torch.empty_like(perm) - inv_perm[perm] = torch.arange(perm.size(0), device=device) - - expert_ids_g = expert_ids[perm] + expert_ids_g, perm = torch.sort(expert_ids) + selected_hidden_states_g = hidden_states[perm // num_top_k] sample_weights_g = sample_weights[perm] - selected_hidden_states_g = hidden_states[token_idx[perm]] # Compute offsets for grouped processing. # histc instead of bincount avoids cuda-graph issues; # CPU requires float input, CUDA requires int input (deterministic mode). + # histc drops values > max, so sentinels (== num_experts) are excluded from the per-expert count. 
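+ # e.g. num_experts=4 and sorted ids [0, 0, 2, 4, 4]: tokens_per_expert = [2, 0, 1, 0] and offsets = [2, 2, 3, 3],
+ # so the grouped matmul stops at row offsets[-1] = 3 and never touches the two sentinel rows at the tail.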
histc_input = expert_ids_g.float() if device.type == "cpu" else expert_ids_g.int() tokens_per_expert = torch.histc(histc_input, bins=self.num_experts, min=0, max=self.num_experts - 1) offsets = torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32) @@ -342,9 +337,11 @@ def fp8_grouped_mm_experts_forward( ) # (S, hidden_dim) # Apply routing weights - weighted_out = proj_out * sample_weights_g.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim) + weighted_out = proj_out * sample_weights_g.unsqueeze(-1) # (S, hidden_dim) # Restore original order + inv_perm = torch.empty_like(perm) + inv_perm[perm] = torch.arange(perm.size(0), device=device) weighted_out = weighted_out[inv_perm] # Accumulate results using deterministic reshape+sum instead of index_add_ @@ -472,8 +469,7 @@ def linear( scale = activation_scale.to(torch.float32) qinput = (input / scale).clamp(min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE) else: - _load_triton_kernel() - global triton_fp8_act_quant + _, triton_fp8_act_quant, _, _ = _load_triton_kernel() qinput, scale = triton_fp8_act_quant( input, self.block_size[1] if self.block_size is not None else input.shape[-1] ) @@ -685,5 +681,5 @@ def convert( } @property - def reverse_op(self) -> "ConversionOps": + def reverse_op(self) -> ConversionOps: return _IdentityOp() diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index b8015a0505b4..2c3ea91eafb6 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations from collections.abc import Callable from functools import wraps @@ -23,8 +24,9 @@ is_torch_less_or_equal, is_torchdynamo_compiling, ) -from .deepgemm import bf16_deepgemm_experts_forward +from .deepgemm import deepgemm_experts_forward from .sonicmoe import sonicmoe_experts_forward +from .tensor_parallel import neutralize_ep_sentinels if is_torch_available(): @@ -103,7 +105,7 @@ def _batched_linear( out = torch.bmm(weight, input.unsqueeze(-1)).squeeze(-1) if bias is not None: - out = out + bias + out.add_(bias) return out @@ -114,24 +116,18 @@ def batched_mm_experts_forward( top_k_index: torch.Tensor, top_k_weights: torch.Tensor, ) -> torch.Tensor: - device = hidden_states.device num_top_k = top_k_index.size(-1) num_tokens = hidden_states.size(0) hidden_dim = hidden_states.size(-1) - # Reshape for easier indexing # S is the number of selected tokens-experts pairs (S = num_tokens * num_top_k) - token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,) + # Replicate each token num_top_k times to align with the flattened (S,) routing tensors. 
+ selected_hidden_states = hidden_states.repeat_interleave(num_top_k, dim=0) sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) # Handle invalid expert IDs from Expert Parallelism (EP) - # When EP is enabled, tokens assigned to experts on other devices are marked with sentinel value >= num_experts - invalid_mask = expert_ids >= self.num_experts - expert_ids = expert_ids.clamp(0, self.num_experts - 1) - - # Get current hidden states for selected samples - selected_hidden_states = hidden_states[token_idx] + neutralize_ep_sentinels(expert_ids, sample_weights, self.num_experts) # Select gate_up or just up projection weights and biases if self.has_gate: @@ -163,9 +159,8 @@ def batched_mm_experts_forward( proj_out, selected_weights, bias=selected_biases, is_transposed=self.is_transposed ) # (S, hidden_dim) - # Apply routing weights and zero out invalid expert contributions + # Apply routing weights weighted_out = proj_out * sample_weights.unsqueeze(-1) # (S, hidden_dim) - weighted_out.masked_fill_(invalid_mask.unsqueeze(-1), 0.0) # Zero out invalid expert contributions # Accumulate results using deterministic reshape+sum instead of index_add_ # index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd @@ -364,7 +359,7 @@ def _grouped_linear( if bias is not None: # We should be able to pass bias to the grouped_mm call, but it's not yet supported. - out = out + bias + out.add_(bias) return out @@ -380,32 +375,26 @@ def grouped_mm_experts_forward( num_tokens = hidden_states.size(0) hidden_dim = hidden_states.size(-1) - # Reshape for easier indexing # S is the number of selected tokens-experts pairs (S = num_tokens * num_top_k) - token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,) sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) - # Handle invalid expert IDs from Expert Parallelism (EP) - invalid_mask = expert_ids >= self.num_experts - expert_ids = expert_ids.clamp(0, self.num_experts - 1) - # Sort by expert for grouped processing - perm = torch.argsort(expert_ids) - inv_perm = torch.empty_like(perm) - inv_perm[perm] = torch.arange(perm.size(0), device=device) - - expert_ids_g = expert_ids[perm] + expert_ids_g, perm = torch.sort(expert_ids) + selected_hidden_states_g = hidden_states[perm // num_top_k] sample_weights_g = sample_weights[perm] - selected_hidden_states_g = hidden_states[token_idx[perm]] # Compute offsets for grouped_mm # using histc instead of bincount to avoid cuda graph issues # With deterministic algorithms, CPU only supports float input, CUDA only supports int input. + # `max=num_experts-1` drops unclamped sentinels (value == num_experts) from the per-expert count. histc_input = expert_ids_g.float() if device.type == "cpu" else expert_ids_g.int() tokens_per_expert = torch.histc(histc_input, bins=self.num_experts, min=0, max=self.num_experts - 1) offsets = torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32) + # Clamp now that offsets are built. We only need this for the per-row bias gather below to stay in-bounds. + expert_ids_g.clamp_(0, self.num_experts - 1) + # Select expert weights and biases # NOTE: We keep all experts here and rely on offsets to target the active ones. 
# I have already implemented a version that only passes the active experts, but @@ -440,12 +429,12 @@ def grouped_mm_experts_forward( proj_out, selected_weights, offsets, bias=selected_biases, is_transposed=self.is_transposed ) # (S, hidden_dim) - # Apply routing weights and zero out invalid expert contributions from EP + # Apply routing weights weighted_out = proj_out * sample_weights_g.unsqueeze(-1) # (S, hidden_dim) - invalid_mask_g = invalid_mask[perm] - weighted_out.masked_fill_(invalid_mask_g.unsqueeze(-1), 0.0) # Restore original order + inv_perm = torch.empty_like(perm) + inv_perm[perm] = torch.arange(perm.size(0), device=device) weighted_out = weighted_out[inv_perm] # (S, hidden_dim) # Accumulate results using deterministic reshape+sum instead of index_add_ @@ -461,10 +450,10 @@ class ExpertsInterface(GeneralInterface): """Interface for registering custom experts forward functions.""" _global_mapping = { - "sonicmoe": sonicmoe_experts_forward, "batched_mm": batched_mm_experts_forward, "grouped_mm": grouped_mm_experts_forward, - "deepgemm": bf16_deepgemm_experts_forward, + "deepgemm": deepgemm_experts_forward, + "sonicmoe": sonicmoe_experts_forward, } def get_interface(self, experts_implementation: str, default: Callable) -> Callable: diff --git a/src/transformers/integrations/sonicmoe.py b/src/transformers/integrations/sonicmoe.py index df6bfbbd8f1a..d6eee485fea7 100644 --- a/src/transformers/integrations/sonicmoe.py +++ b/src/transformers/integrations/sonicmoe.py @@ -18,6 +18,8 @@ Requirements: CUDA, `kernels`, `nvidia-cutlass-dsl`, has_gate=True. """ +from __future__ import annotations + import functools import torch @@ -64,7 +66,8 @@ def _load_sonic_kernel(): kernel = lazy_load_kernel("sonic-moe") if kernel is None: raise ImportError( - "sonic-moe kernel not found. Make sure you have the `kernels` and `nvidia-cutlass-dsl` packages installed." + "Failed to load the sonic-moe kernel — check that `kernels-community/sonic-moe` " + "has a build matching the current torch/CUDA." ) ActivationType = getattr(getattr(kernel, "enums", None), "ActivationType", None) diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 82d6d284f052..0c4557e4d3d7 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -1079,6 +1079,20 @@ def update_module_attributes(self, module: nn.Module): module.num_experts = self.get_expected_sharded_shape((self.empty_param.shape[0],))[0] +def neutralize_ep_sentinels(expert_ids, sample_weights, num_experts) -> None: + """Make EP sentinel slots (`expert_ids >= num_experts`) no-ops for indexing backends. + + Mutates in place: clamps `expert_ids` in-range (so weight indexing stays valid) and zeros + `sample_weights` at sentinel slots (so their expert GEMM output contributes nothing). + + Sentinel tokens still go through the expert GEMMs; filtering them beforehand needs a host sync + or dynamic-shape kernels, both of which break CUDA graphs — so we keep the shape-preserving path. + Grouped-GEMM backends can skip sentinels via offsets instead — see `grouped_mm_experts_forward`. + """ + sample_weights.masked_fill_(expert_ids >= num_experts, 0.0) + expert_ids.clamp_(0, num_experts - 1) + + class RouterParallel(TensorParallelLayer): """ Allows to reshape the router scores to support running expert parallel. 
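For readers following the routing changes above, here is a minimal, self-contained sketch (not part of the patch: toy sizes, an identity stand-in for the per-expert GEMMs, and a helper name of my own) of the bookkeeping that `grouped_mm_experts_forward` and the fp8/deep-gemm variants share: flatten the top-k routing, sort by expert, count tokens per expert with `histc` so EP sentinels drop out of the offsets, apply the (already zero) sentinel weights, undo the permutation, and reduce with a deterministic reshape+sum.

import torch

def toy_grouped_dispatch(hidden_states, top_k_index, top_k_weights, num_experts):
    # hidden_states: (T, H); top_k_index / top_k_weights: (T, K).
    # Sentinel ids (>= num_experts, from expert parallelism) must arrive with weight 0.
    num_tokens, hidden_dim = hidden_states.shape
    num_top_k = top_k_index.size(-1)

    expert_ids = top_k_index.reshape(-1)        # (S,), S = T * K
    sample_weights = top_k_weights.reshape(-1)  # (S,)

    # Sort token-expert pairs by expert; sentinels land at the tail.
    expert_ids_g, perm = torch.sort(expert_ids)
    tokens_g = hidden_states[perm // num_top_k]
    weights_g = sample_weights[perm]

    # Per-expert counts; max=num_experts-1 drops sentinels, so offsets[-1] excludes them.
    counts = torch.histc(expert_ids_g.float(), bins=num_experts, min=0, max=num_experts - 1)
    offsets = torch.cumsum(counts, dim=0, dtype=torch.int32)

    # Stand-in for the per-expert GEMMs so the sketch stays runnable.
    expert_out = tokens_g.clone()

    # Routing weights are zero at sentinel slots, so those rows contribute nothing.
    weighted = expert_out * weights_g.unsqueeze(-1)

    # Undo the sort, then reduce the K copies of each token deterministically.
    inv_perm = torch.empty_like(perm)
    inv_perm[perm] = torch.arange(perm.numel(), device=perm.device)
    weighted = weighted[inv_perm]
    return weighted.view(num_tokens, num_top_k, hidden_dim).sum(dim=1), offsets

# 3 tokens, top-2 routing over 4 local experts, one EP sentinel (id 4) with weight 0.
h = torch.randn(3, 8)
idx = torch.tensor([[0, 2], [1, 4], [3, 0]])
w = torch.tensor([[0.6, 0.4], [1.0, 0.0], [0.5, 0.5]])
out, offsets = toy_grouped_dispatch(h, idx, w, num_experts=4)
print(out.shape, offsets.tolist())  # torch.Size([3, 8]) [2, 3, 4, 5]

In the real code paths, `offsets` (or deep-gemm's grouped layout) is what lets the grouped backends skip the sentinel rows entirely; the batched path has no such mechanism, which is why it clamps the ids and relies on the zero routing weights instead.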
From 84552ae98465ad2ed13bbbc67ff08b79b9bcb1bd Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 10:48:41 +0200 Subject: [PATCH 05/22] comments --- src/transformers/integrations/deepgemm.py | 19 +++++++++++-------- .../integrations/finegrained_fp8.py | 5 ++++- src/transformers/integrations/moe.py | 12 +++++++++--- 3 files changed, 24 insertions(+), 12 deletions(-) diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py index 4d5fcf8095b4..10fb7adcda8e 100644 --- a/src/transformers/integrations/deepgemm.py +++ b/src/transformers/integrations/deepgemm.py @@ -313,9 +313,11 @@ def deepgemm_experts_forward( sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) - # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail - # and `_build_deepgemm_contiguous_layout` marks their positions as skipped (-1 on Hopper, beyond - # the cumsum on Blackwell) — so deep-gemm performs no real GEMM work for them. + # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, + # `_build_deepgemm_contiguous_layout` marks their positions as skipped (-1 on Hopper, beyond the + # cumsum on Blackwell), and deep-gemm skips them — so sentinels cost no real GEMM compute. Their + # routing weights are already zero (RouterParallel masks them at dispatch) so the weighted mul + # contributes nothing. # Sort by expert for grouped processing expert_ids_g, perm = torch.sort(expert_ids) selected_hidden_states_g = hidden_states[perm // num_top_k] @@ -326,17 +328,17 @@ def deepgemm_experts_forward( expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT, use_psum_layout=use_psum_layout ) - # Clamp now that the layout has been built — needed for the per-row bias gather below to stay - # in-bounds. Bias added to sentinel positions falls in rows the kernel skips, so harmless. - expert_ids_g.clamp_(0, self.num_experts - 1) + if self.has_bias: + # Clamp now that the layout has been built — needed for the per-row bias gather below to stay + # in-bounds. Bias added to sentinel positions falls in rows the kernel skips, so harmless. + expert_ids_g.clamp_(0, self.num_experts - 1) # --- Up projection per expert (deep-gemm grouped contiguous, bf16) --- w_up = self.gate_up_proj if self.has_gate else self.up_proj # Output dim is the last weight axis when transposed (E, K, N), second axis when not (E, N, K). up_out_dim = w_up.shape[-1] if self.is_transposed else w_up.shape[1] act = _pad_for_deepgemm(selected_hidden_states_g, sorted_to_padded, total_padded_rows) - # `torch.zeros` so sentinel rows read back as 0 at unpad time (kernel leaves them untouched). - proj_out = torch.zeros(total_padded_rows, up_out_dim, device=device, dtype=hidden_states.dtype) + proj_out = torch.empty(total_padded_rows, up_out_dim, device=device, dtype=hidden_states.dtype) m_grouped_bf16_gemm(act, w_up, proj_out, grouped_layout, use_psum_layout=use_psum_layout) # The kernel has no bias input -> add per-expert bias in-place on the unpadded slice; @@ -352,6 +354,7 @@ def deepgemm_experts_forward( proj_out = self.act_fn(proj_out) # --- Down projection per expert (deep-gemm grouped contiguous, bf16) --- + # Zero-init: unpad later reads sentinel-row positions the kernel never writes. 
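+ # (with torch.empty the skipped rows could hold NaN/inf, and 0 * NaN is still NaN after the routing-weight multiply)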
out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=hidden_states.dtype) m_grouped_bf16_gemm(proj_out, self.down_proj, out, grouped_layout, use_psum_layout=use_psum_layout) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index 9579d50c5fd7..e8d3f25c3edc 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -295,6 +295,10 @@ def fp8_grouped_mm_experts_forward( sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) + # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, + # `histc(max=num_experts-1)` drops them from `tokens_per_expert`, and the grouped matmul skips + # rows beyond `offsets[-1]` — so sentinels cost no real GEMM compute. Their routing weights are + # already zero (RouterParallel masks them at dispatch) so the weighted mul contributes nothing. # Sort by expert for grouped processing expert_ids_g, perm = torch.sort(expert_ids) selected_hidden_states_g = hidden_states[perm // num_top_k] @@ -303,7 +307,6 @@ def fp8_grouped_mm_experts_forward( # Compute offsets for grouped processing. # histc instead of bincount avoids cuda-graph issues; # CPU requires float input, CUDA requires int input (deterministic mode). - # histc drops values > max, so sentinels (== num_experts) are excluded from the per-expert count. histc_input = expert_ids_g.float() if device.type == "cpu" else expert_ids_g.int() tokens_per_expert = torch.histc(histc_input, bins=self.num_experts, min=0, max=self.num_experts - 1) offsets = torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32) diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index 2c3ea91eafb6..705e07763bd4 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -379,6 +379,10 @@ def grouped_mm_experts_forward( sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) + # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, + # `histc(max=num_experts-1)` drops them from `tokens_per_expert`, and grouped_mm skips rows + # beyond `offsets[-1]` — so sentinels cost no real GEMM compute. Their routing weights are + # already zero (RouterParallel masks them at dispatch) so the weighted mul contributes nothing. # Sort by expert for grouped processing expert_ids_g, perm = torch.sort(expert_ids) selected_hidden_states_g = hidden_states[perm // num_top_k] @@ -387,19 +391,21 @@ def grouped_mm_experts_forward( # Compute offsets for grouped_mm # using histc instead of bincount to avoid cuda graph issues # With deterministic algorithms, CPU only supports float input, CUDA only supports int input. - # `max=num_experts-1` drops unclamped sentinels (value == num_experts) from the per-expert count. histc_input = expert_ids_g.float() if device.type == "cpu" else expert_ids_g.int() tokens_per_expert = torch.histc(histc_input, bins=self.num_experts, min=0, max=self.num_experts - 1) offsets = torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32) - # Clamp now that offsets are built. We only need this for the per-row bias gather below to stay in-bounds. - expert_ids_g.clamp_(0, self.num_experts - 1) + if self.has_bias: + # Clamp now that the layout has been built — needed for the per-row bias gather below to stay + # in-bounds. 
Bias added to sentinel positions falls in rows the kernel skips, so harmless. + expert_ids_g.clamp_(0, self.num_experts - 1) # Select expert weights and biases # NOTE: We keep all experts here and rely on offsets to target the active ones. # I have already implemented a version that only passes the active experts, but # to do so I had to use torch.unique which breaks the graph capture (data-dependent). # Also there were no speedup gains from it in my experiments, even in eager mode. + # NOTE: The grouped_mm kernel only targets the active experts / tokens via the offsets if self.has_gate: selected_weights = self.gate_up_proj selected_biases = self.gate_up_proj_bias[expert_ids_g] if self.has_bias else None From 1d9f319b9623d414ca8e8b7b931c3081efc27100 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 11:00:48 +0200 Subject: [PATCH 06/22] get rid of neutralize_ep_sentinels --- src/transformers/integrations/finegrained_fp8.py | 7 ++++--- src/transformers/integrations/moe.py | 7 ++++--- src/transformers/integrations/tensor_parallel.py | 14 -------------- 3 files changed, 8 insertions(+), 20 deletions(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index e8d3f25c3edc..f08329003df4 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -27,7 +27,6 @@ from .deepgemm import fp8_deepgemm_experts_forward, fp8_deepgemm_matmul from .hub_kernels import lazy_load_kernel from .moe import ExpertsInterface, use_experts_implementation -from .tensor_parallel import neutralize_ep_sentinels logger = logging.get_logger(__name__) @@ -232,8 +231,10 @@ def fp8_batched_mm_experts_forward( sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) - # Handle invalid expert IDs from Expert Parallelism (EP) - neutralize_ep_sentinels(expert_ids, sample_weights, self.num_experts) + # Clamp EP sentinels so per-token weight indexing stays in-bounds. Routing weights are already + # zero at sentinel slots (RouterParallel masks them at dispatch), so the weighted mul drops + # those contributions — we pay the wasted GEMM compute because batched_mm has no offset to skip. + expert_ids.clamp_(0, self.num_experts - 1) # --- Up projection per expert (FP8 batched) --- proj_out = triton_batched_fp8_matmul( diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index 705e07763bd4..4f1f9c315959 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -26,7 +26,6 @@ ) from .deepgemm import deepgemm_experts_forward from .sonicmoe import sonicmoe_experts_forward -from .tensor_parallel import neutralize_ep_sentinels if is_torch_available(): @@ -126,8 +125,10 @@ def batched_mm_experts_forward( sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) - # Handle invalid expert IDs from Expert Parallelism (EP) - neutralize_ep_sentinels(expert_ids, sample_weights, self.num_experts) + # Clamp EP sentinels so `gate_up_proj[expert_ids]` stays in-bounds. Routing weights are already + # zero at sentinel slots (RouterParallel masks them at dispatch), so the weighted mul drops + # those contributions — we pay the wasted GEMM compute because batched_mm has no offset to skip. 
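+ # e.g. num_experts=4: ids [1, 4, 0] with weights [0.7, 0.0, 0.3] become [1, 3, 0]; the clamped row still runs
+ # through expert 3's GEMM but its weight of 0.0 removes it from the final sum.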
+ expert_ids.clamp_(0, self.num_experts - 1) # Select gate_up or just up projection weights and biases if self.has_gate: diff --git a/src/transformers/integrations/tensor_parallel.py b/src/transformers/integrations/tensor_parallel.py index 0c4557e4d3d7..82d6d284f052 100644 --- a/src/transformers/integrations/tensor_parallel.py +++ b/src/transformers/integrations/tensor_parallel.py @@ -1079,20 +1079,6 @@ def update_module_attributes(self, module: nn.Module): module.num_experts = self.get_expected_sharded_shape((self.empty_param.shape[0],))[0] -def neutralize_ep_sentinels(expert_ids, sample_weights, num_experts) -> None: - """Make EP sentinel slots (`expert_ids >= num_experts`) no-ops for indexing backends. - - Mutates in place: clamps `expert_ids` in-range (so weight indexing stays valid) and zeros - `sample_weights` at sentinel slots (so their expert GEMM output contributes nothing). - - Sentinel tokens still go through the expert GEMMs; filtering them beforehand needs a host sync - or dynamic-shape kernels, both of which break CUDA graphs — so we keep the shape-preserving path. - Grouped-GEMM backends can skip sentinels via offsets instead — see `grouped_mm_experts_forward`. - """ - sample_weights.masked_fill_(expert_ids >= num_experts, 0.0) - expert_ids.clamp_(0, num_experts - 1) - - class RouterParallel(TensorParallelLayer): """ Allows to reshape the router scores to support running expert parallel. From 9b8604341198535dd11016bbf35100139dd9a2bd Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 11:24:05 +0200 Subject: [PATCH 07/22] remove deepgemm stuff --- src/transformers/integrations/deepgemm.py | 379 ------------------ .../integrations/finegrained_fp8.py | 251 +++++++++++- src/transformers/integrations/moe.py | 2 - 3 files changed, 249 insertions(+), 383 deletions(-) delete mode 100644 src/transformers/integrations/deepgemm.py diff --git a/src/transformers/integrations/deepgemm.py b/src/transformers/integrations/deepgemm.py deleted file mode 100644 index 10fb7adcda8e..000000000000 --- a/src/transformers/integrations/deepgemm.py +++ /dev/null @@ -1,379 +0,0 @@ -# Copyright 2026 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -"""DeepGEMM integration: fused grouped GEMM kernels from `kernels-community/deep-gemm`. - -Provides: -- `fp8_deepgemm_matmul`: FP8 dense matmul used as a fast path inside the finegrained-fp8 Linear. -- `fp8_deepgemm_experts_forward`: FP8 M-grouped experts forward, registered as "deepgemm" in the FP8 ExpertsInterface. -- `deepgemm_experts_forward`: BF16 M-grouped experts forward, registered as "deepgemm" in the ExpertsInterface. - -Requirements: CUDA, Hopper (SM90+), CUDA runtime >= 12.3, `kernels`. 
-""" - -from __future__ import annotations - -import functools - -import torch - -from ..utils import logging -from ..utils.import_utils import get_cuda_runtime_version, is_kernels_available, resolve_internal_import -from .hub_kernels import lazy_load_kernel - - -logger = logging.get_logger(__name__) - -# DeepGEMM requires M-dimension alignment to 128 for TMA-based contiguous grouped GEMM. -# TMA is an H100 hardware addition that allows applications to asynchronously and -# bi-directionally transfer 1D-5D tensors between GPU global and shared memory. -_DEEPGEMM_M_ALIGNMENT = 128 - - -@functools.cache -def _load_deepgemm_kernel(): - """ - Load deep-gemm once and return its required symbols. - - Raises: - ImportError if CUDA/hardware requirements are not met, or the kernel or - required symbols are not found. - - Returns: - Tuple of (fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous, - m_grouped_bf16_gemm_nt_contiguous, m_grouped_bf16_gemm_nn_contiguous, - per_token_cast_to_fp8) from the deep-gemm kernel. - """ - if not is_kernels_available(): - raise ImportError("deep-gemm kernel requires the `kernels` package. Install it with `pip install -U kernels`.") - - if not torch.cuda.is_available(): - raise ImportError( - "deep-gemm kernel requires CUDA, but CUDA is not available. Use a different `experts_implementation`." - ) - - # deep-gemm requires Hopper (SM90) or newer for FP8 WGMMA instructions - major = torch.cuda.get_device_capability()[0] - if major < 9: - raise ImportError( - f"deep-gemm requires a Hopper (SM90+) or newer GPU, but the current device " - f"has compute capability {major}.x. Use a different `experts_implementation`." - ) - - # deep-gemm requires CUDA runtime >= 12.3 - cuda_major, cuda_minor = get_cuda_runtime_version() - if cuda_major < 12 or (cuda_major == 12 and cuda_minor < 3): - raise ImportError( - f"deep-gemm requires CUDA runtime 12.3+, but found {cuda_major}.{cuda_minor}. " - "Please upgrade your CUDA toolkit or use a different `experts_implementation`." - ) - - kernel = lazy_load_kernel("deep-gemm") - if kernel is None: - raise ImportError( - "Failed to load the deep-gemm kernel — check that `kernels-community/deep-gemm` " - "has a build matching the current torch/CUDA." - ) - - fp8_gemm_nt = getattr(kernel, "fp8_gemm_nt", None) - m_grouped_fp8_gemm_nt_contiguous = getattr(kernel, "m_grouped_fp8_gemm_nt_contiguous", None) - m_grouped_bf16_gemm_nt_contiguous = getattr(kernel, "m_grouped_bf16_gemm_nt_contiguous", None) - m_grouped_bf16_gemm_nn_contiguous = getattr(kernel, "m_grouped_bf16_gemm_nn_contiguous", None) - per_token_cast_to_fp8 = resolve_internal_import(kernel, chained_path="utils.per_token_cast_to_fp8") - - missing = [ - name - for name, attr in [ - ("fp8_gemm_nt", fp8_gemm_nt), - ("m_grouped_fp8_gemm_nt_contiguous", m_grouped_fp8_gemm_nt_contiguous), - ("m_grouped_bf16_gemm_nt_contiguous", m_grouped_bf16_gemm_nt_contiguous), - ("m_grouped_bf16_gemm_nn_contiguous", m_grouped_bf16_gemm_nn_contiguous), - ("utils.per_token_cast_to_fp8", per_token_cast_to_fp8), - ] - if attr is None - ] - if missing: - raise ImportError( - f"deep-gemm kernel is missing required symbols: {', '.join(missing)}. " - "Please update the `kernels` package (`pip install -U kernels`)." 
- ) - - return ( - fp8_gemm_nt, - m_grouped_fp8_gemm_nt_contiguous, - m_grouped_bf16_gemm_nt_contiguous, - m_grouped_bf16_gemm_nn_contiguous, - per_token_cast_to_fp8, - ) - - -def fp8_deepgemm_matmul( - A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor, - output_dtype: torch.dtype = torch.float32, -) -> torch.Tensor: - """ - FP8 dense matmul via deep-gemm's `fp8_gemm_nt`. Block-wise 128x128 scales expected. - - Args: - A: (M, K) float8_e4m3fn — quantized activations - B: (N, K) float8_e4m3fn — quantized weights - As: (M, K//128) float32 — per-block activation scales - Bs: (N//128, K//128) float32 — per-block weight scales - output_dtype: desired output dtype. - """ - fp8_gemm_nt, _, _, _, _ = _load_deepgemm_kernel() - A_2d = A.view(-1, A.shape[-1]) - As_2d = As.view(-1, As.shape[-1]) - output = torch.empty(A_2d.shape[0], B.shape[0], device=A.device, dtype=output_dtype) - fp8_gemm_nt((A_2d, As_2d.float()), (B, Bs.float()), output) - return output.view(A.shape[:-1] + (B.shape[0],)) - - -def _build_deepgemm_contiguous_layout( - expert_ids_sorted: torch.Tensor, num_experts: int, alignment: int, use_psum_layout: bool -) -> tuple: - """Build the TMA-aligned layout deep-gemm's grouped GEMM expects. - - Returns `(sorted_to_padded, grouped_layout, total_padded_rows)`. `grouped_layout` encodes - expert boundaries as a cumsum of aligned counts on Blackwell (`use_psum_layout=True`) or - per-row expert ids with -1 for padding on Hopper. - - Accepts EP sentinels: values in `expert_ids_sorted` equal to `num_experts` (unclamped sentinels) - are routed past the last aligned expert block and marked `-1` in the Hopper layout (and - excluded from the Blackwell cumsum), so deep-gemm skips them. - """ - device = expert_ids_sorted.device - num_tokens = expert_ids_sorted.size(0) - # histc drops values > max, so EP sentinels (== num_experts) are excluded from the per-expert count. - tokens_per_expert = torch.histc(expert_ids_sorted.int(), bins=num_experts, min=0, max=num_experts - 1).long() - aligned_tokens_per_expert = ((tokens_per_expert + alignment - 1) // alignment) * alignment - # Upper bound avoids GPU->CPU sync; padding rows are skipped by deep-gemm. - total_padded_rows = num_tokens + min(num_tokens, num_experts) * (alignment - 1) - - # Zero-prepended inclusive cumsum of per-expert padding. Indices [0, num_experts) give the - # exclusive cumsum (padding before expert i) and index `num_experts` gives `sum(padding)`, - # which routes EP sentinels past all valid aligned expert blocks on Blackwell (where the - # kernel stops at `aligned_cumsum[-1]`) — so sentinels don't go through the GEMM. - padding_per_expert = aligned_tokens_per_expert - tokens_per_expert - cumulative_padding = torch.nn.functional.pad(padding_per_expert.cumsum(0), (1, 0)) - sorted_to_padded = torch.arange(num_tokens, device=device) + cumulative_padding[expert_ids_sorted] - - if use_psum_layout: # Blackwell (SM100+) - # psum layout: cumsum of *aligned* per-expert counts — sentinels sit at positions >= - # `grouped_layout[-1]` (by construction of `cumulative_padding`), so the scheduler - # stops before them. The kernel's `num_m_blocks = ceil_div(layout[i] - align(layout[i-1], 128), BLOCK_M)` - # between experts only matches the padded tensor when the stored cumsum is over aligned counts. - grouped_layout = aligned_tokens_per_expert.cumsum(0).int() - else: - # Hopper: per-row expert id, -1 for padding rows and for sentinel slots (kernel skips -1). 
- grouped_layout = torch.full((total_padded_rows,), -1, device=device, dtype=torch.int32) - grouped_layout[sorted_to_padded] = torch.where(expert_ids_sorted < num_experts, expert_ids_sorted.int(), -1) - - return sorted_to_padded, grouped_layout, total_padded_rows - - -def _pad_for_deepgemm(x: torch.Tensor, sorted_to_padded: torch.Tensor, total_padded_rows: int) -> torch.Tensor: - """Pad a sorted tensor into the TMA-aligned contiguous layout. - - Padding rows are left uninitialized — the kernel skips them via `grouped_layout=-1` (Hopper) - or via the psum offsets (Blackwell), so their values never enter the computation. - """ - padded = torch.empty(total_padded_rows, *x.shape[1:], device=x.device, dtype=x.dtype) - padded[sorted_to_padded] = x - return padded - - -def _unpad_from_deepgemm_contiguous_layout(x_padded: torch.Tensor, sorted_to_padded: torch.Tensor) -> torch.Tensor: - """Remove padding rows from the TMA-aligned contiguous layout.""" - return x_padded[sorted_to_padded] - - -def fp8_deepgemm_experts_forward( - self: torch.nn.Module, - hidden_states: torch.Tensor, - top_k_index: torch.Tensor, - top_k_weights: torch.Tensor, -) -> torch.Tensor: - if self.activation_scheme == "static": - raise NotImplementedError( - "deepgemm experts dispatch does not support activation_scheme='static'. " - "Use the default eager dispatch or switch to activation_scheme='dynamic'." - ) - if self.block_size is None: - raise ValueError( - "deep-gemm requires block-wise quantization (block_size=[128, 128]), " - "but got per-tensor quantization (block_size=None)." - ) - if self.block_size[0] != 128 or self.block_size[1] != 128: - raise ValueError(f"deep-gemm requires block_size=(128, 128), got {self.block_size}") - - _, m_grouped_fp8_gemm_nt_contiguous, _, _, per_token_cast_to_fp8 = _load_deepgemm_kernel() - - device = hidden_states.device - num_top_k = top_k_index.size(-1) - num_tokens = hidden_states.size(0) - hidden_dim = hidden_states.size(-1) - - # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) - sample_weights = top_k_weights.reshape(-1) # (S,) - expert_ids = top_k_index.reshape(-1) # (S,) - - # Sort by expert for grouped processing - expert_ids_g, perm = torch.sort(expert_ids) - selected_hidden_states_g = hidden_states[perm // num_top_k] - sample_weights_g = sample_weights[perm] - - use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 - sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( - expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT, use_psum_layout=use_psum_layout - ) - - # --- Up projection per expert (deep-gemm grouped contiguous) --- - w_up = self.gate_up_proj if self.has_gate else self.up_proj - ws_up = self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv - act_fp8, act_scales = per_token_cast_to_fp8(selected_hidden_states_g, use_ue8m0=False) - act_fp8 = _pad_for_deepgemm(act_fp8, sorted_to_padded, total_padded_rows) - act_scales = _pad_for_deepgemm(act_scales, sorted_to_padded, total_padded_rows) - proj_out = torch.empty(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16) - m_grouped_fp8_gemm_nt_contiguous( - (act_fp8, act_scales), (w_up, ws_up.float()), proj_out, grouped_layout, use_psum_layout=use_psum_layout - ) - - # Apply gating or activation - if self.has_gate: - proj_out = self._apply_gate(proj_out) - else: - proj_out = self.act_fn(proj_out) - - # --- Down projection per expert (deep-gemm grouped contiguous) --- - proj_fp8, proj_scales = 
per_token_cast_to_fp8(proj_out, use_ue8m0=False) - proj_out = torch.empty(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) - m_grouped_fp8_gemm_nt_contiguous( - (proj_fp8, proj_scales), - (self.down_proj, self.down_proj_scale_inv.float()), - proj_out, - grouped_layout, - use_psum_layout=use_psum_layout, - ) - - # Remove padding rows - proj_out = _unpad_from_deepgemm_contiguous_layout(proj_out, sorted_to_padded) - - # Apply routing weights - weighted_out = proj_out * sample_weights_g.unsqueeze(-1) # (S, hidden_dim) - - # Restore original order - inv_perm = torch.empty_like(perm) - inv_perm[perm] = torch.arange(perm.size(0), device=device) - weighted_out = weighted_out[inv_perm] - - # Accumulate results using deterministic reshape+sum instead of index_add_ - # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd) - final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1) - - return final_hidden_states.to(hidden_states.dtype) - - -def deepgemm_experts_forward( - self: torch.nn.Module, - hidden_states: torch.Tensor, - top_k_index: torch.Tensor, - top_k_weights: torch.Tensor, -) -> torch.Tensor: - if hidden_states.dtype != torch.bfloat16: - raise ValueError(f"deepgemm path requires bfloat16 hidden states, got {hidden_states.dtype}") - - _, _, m_grouped_bf16_gemm_nt_contiguous, m_grouped_bf16_gemm_nn_contiguous, _ = _load_deepgemm_kernel() - # Non-transposed HF experts have weight layout (E, N, K) -> NT kernel. - # Transposed HF experts have weight layout (E, K, N) -> NN kernel. - m_grouped_bf16_gemm = ( - m_grouped_bf16_gemm_nn_contiguous if self.is_transposed else m_grouped_bf16_gemm_nt_contiguous - ) - - device = hidden_states.device - num_top_k = top_k_index.size(-1) - num_tokens = hidden_states.size(0) - hidden_dim = hidden_states.size(-1) - - # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) - sample_weights = top_k_weights.reshape(-1) # (S,) - expert_ids = top_k_index.reshape(-1) # (S,) - - # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, - # `_build_deepgemm_contiguous_layout` marks their positions as skipped (-1 on Hopper, beyond the - # cumsum on Blackwell), and deep-gemm skips them — so sentinels cost no real GEMM compute. Their - # routing weights are already zero (RouterParallel masks them at dispatch) so the weighted mul - # contributes nothing. - # Sort by expert for grouped processing - expert_ids_g, perm = torch.sort(expert_ids) - selected_hidden_states_g = hidden_states[perm // num_top_k] - sample_weights_g = sample_weights[perm] - - use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 - sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( - expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT, use_psum_layout=use_psum_layout - ) - - if self.has_bias: - # Clamp now that the layout has been built — needed for the per-row bias gather below to stay - # in-bounds. Bias added to sentinel positions falls in rows the kernel skips, so harmless. - expert_ids_g.clamp_(0, self.num_experts - 1) - - # --- Up projection per expert (deep-gemm grouped contiguous, bf16) --- - w_up = self.gate_up_proj if self.has_gate else self.up_proj - # Output dim is the last weight axis when transposed (E, K, N), second axis when not (E, N, K). 
- up_out_dim = w_up.shape[-1] if self.is_transposed else w_up.shape[1] - act = _pad_for_deepgemm(selected_hidden_states_g, sorted_to_padded, total_padded_rows) - proj_out = torch.empty(total_padded_rows, up_out_dim, device=device, dtype=hidden_states.dtype) - m_grouped_bf16_gemm(act, w_up, proj_out, grouped_layout, use_psum_layout=use_psum_layout) - - # The kernel has no bias input -> add per-expert bias in-place on the unpadded slice; - # padding rows get discarded at unpad time. - if self.has_bias: - up_bias = self.gate_up_proj_bias if self.has_gate else self.up_proj_bias - proj_out.index_add_(0, sorted_to_padded, up_bias[expert_ids_g]) - - # Apply gating or activation - if self.has_gate: - proj_out = self._apply_gate(proj_out) - else: - proj_out = self.act_fn(proj_out) - - # --- Down projection per expert (deep-gemm grouped contiguous, bf16) --- - # Zero-init: unpad later reads sentinel-row positions the kernel never writes. - out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=hidden_states.dtype) - m_grouped_bf16_gemm(proj_out, self.down_proj, out, grouped_layout, use_psum_layout=use_psum_layout) - - if self.has_bias: - out.index_add_(0, sorted_to_padded, self.down_proj_bias[expert_ids_g]) - - # Remove padding rows - out = _unpad_from_deepgemm_contiguous_layout(out, sorted_to_padded) - - # Apply routing weights - weighted_out = out * sample_weights_g.unsqueeze(-1) # (S, hidden_dim) - - # Restore original order - inv_perm = torch.empty_like(perm) - inv_perm[perm] = torch.arange(perm.size(0), device=device) - weighted_out = weighted_out[inv_perm] - - # Accumulate results using deterministic reshape+sum instead of index_add_ - # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd) - final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1) - - return final_hidden_states.to(hidden_states.dtype) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index f08329003df4..c51d2322fe36 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -23,8 +23,7 @@ from ..core_model_loading import ConversionOps, _IdentityOp from ..quantizers.quantizers_utils import should_convert_module from ..utils import logging -from ..utils.import_utils import is_kernels_available -from .deepgemm import fp8_deepgemm_experts_forward, fp8_deepgemm_matmul +from ..utils.import_utils import get_cuda_runtime_version, is_kernels_available, resolve_internal_import from .hub_kernels import lazy_load_kernel from .moe import ExpertsInterface, use_experts_implementation @@ -86,6 +85,162 @@ def _load_triton_kernel(): return triton_fp8_matmul, triton_fp8_act_quant, triton_batched_fp8_matmul, triton_grouped_fp8_matmul +# DeepGEMM requires M-dimension alignment to 128 for TMA-based contiguous grouped GEMM. +# TMA is an H100 hardware addition that allows applications to asynchronously and +# bi-directionally transfer 1D-5D tensors between GPU global and shared memory. +_DEEPGEMM_M_ALIGNMENT = 128 + + +@functools.cache +def _load_deepgemm_kernel(): + """ + Load deep-gemm once and return its required symbols. + + Raises: + ImportError if CUDA/hardware requirements are not met, or the kernel or + required symbols are not found. + + Returns: + Tuple of (fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous, per_token_cast_to_fp8) + from the deep-gemm kernel. 
+ """ + if not is_kernels_available(): + raise ImportError("deep-gemm kernel requires the `kernels` package. Install it with `pip install -U kernels`.") + + if not torch.cuda.is_available(): + raise ImportError( + "deep-gemm kernel requires CUDA, but CUDA is not available. Use a different `experts_implementation`." + ) + + # deep-gemm requires Hopper (SM90) or newer for FP8 WGMMA instructions + major = torch.cuda.get_device_capability()[0] + if major < 9: + raise ImportError( + f"deep-gemm requires a Hopper (SM90+) or newer GPU, but the current device " + f"has compute capability {major}.x. Use a different `experts_implementation`." + ) + + # deep-gemm requires CUDA runtime >= 12.3 + cuda_major, cuda_minor = get_cuda_runtime_version() + if cuda_major < 12 or (cuda_major == 12 and cuda_minor < 3): + raise ImportError( + f"deep-gemm requires CUDA runtime 12.3+, but found {cuda_major}.{cuda_minor}. " + "Please upgrade your CUDA toolkit or use a different `experts_implementation`." + ) + + kernel = lazy_load_kernel("deep-gemm") + if kernel is None: + raise ImportError( + "Failed to load the deep-gemm kernel — check that `kernels-community/deep-gemm` " + "has a build matching the current torch/CUDA." + ) + + fp8_gemm_nt = getattr(kernel, "fp8_gemm_nt", None) + m_grouped_fp8_gemm_nt_contiguous = getattr(kernel, "m_grouped_fp8_gemm_nt_contiguous", None) + per_token_cast_to_fp8 = resolve_internal_import(kernel, chained_path="utils.per_token_cast_to_fp8") + + missing = [ + name + for name, attr in [ + ("fp8_gemm_nt", fp8_gemm_nt), + ("m_grouped_fp8_gemm_nt_contiguous", m_grouped_fp8_gemm_nt_contiguous), + ("utils.per_token_cast_to_fp8", per_token_cast_to_fp8), + ] + if attr is None + ] + if missing: + raise ImportError( + f"deep-gemm kernel is missing required symbols: {', '.join(missing)}. " + "Please update the `kernels` package (`pip install -U kernels`)." + ) + + return fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous, per_token_cast_to_fp8 + + +def fp8_deepgemm_matmul( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + FP8 dense matmul via deep-gemm's `fp8_gemm_nt`. Block-wise 128x128 scales expected. + + Args: + A: (M, K) float8_e4m3fn — quantized activations + B: (N, K) float8_e4m3fn — quantized weights + As: (M, K//128) float32 — per-block activation scales + Bs: (N//128, K//128) float32 — per-block weight scales + output_dtype: desired output dtype. + """ + fp8_gemm_nt, _, _ = _load_deepgemm_kernel() + A_2d = A.view(-1, A.shape[-1]) + As_2d = As.view(-1, As.shape[-1]) + output = torch.empty(A_2d.shape[0], B.shape[0], device=A.device, dtype=output_dtype) + fp8_gemm_nt((A_2d, As_2d.float()), (B, Bs.float()), output) + return output.view(A.shape[:-1] + (B.shape[0],)) + + +def _build_deepgemm_contiguous_layout( + expert_ids_sorted: torch.Tensor, num_experts: int, alignment: int, use_psum_layout: bool +) -> tuple: + """Build the TMA-aligned layout deep-gemm's grouped GEMM expects. + + Returns `(sorted_to_padded, grouped_layout, total_padded_rows)`. `grouped_layout` encodes + expert boundaries as a cumsum of aligned counts on Blackwell (`use_psum_layout=True`) or + per-row expert ids with -1 for padding on Hopper. + + Accepts EP sentinels: values in `expert_ids_sorted` equal to `num_experts` (unclamped sentinels) + are routed past the last aligned expert block and marked `-1` in the Hopper layout (and + excluded from the Blackwell cumsum), so deep-gemm skips them. 
+ """ + device = expert_ids_sorted.device + num_tokens = expert_ids_sorted.size(0) + # histc drops values > max, so EP sentinels (== num_experts) are excluded from the per-expert count. + tokens_per_expert = torch.histc(expert_ids_sorted.int(), bins=num_experts, min=0, max=num_experts - 1).long() + aligned_tokens_per_expert = ((tokens_per_expert + alignment - 1) // alignment) * alignment + # Upper bound avoids GPU->CPU sync; padding rows are skipped by deep-gemm. + total_padded_rows = num_tokens + min(num_tokens, num_experts) * (alignment - 1) + + # Zero-prepended inclusive cumsum of per-expert padding. Indices [0, num_experts) give the + # exclusive cumsum (padding before expert i) and index `num_experts` gives `sum(padding)`, + # which routes EP sentinels past all valid aligned expert blocks on Blackwell (where the + # kernel stops at `aligned_cumsum[-1]`) — so sentinels don't go through the GEMM. + padding_per_expert = aligned_tokens_per_expert - tokens_per_expert + cumulative_padding = torch.nn.functional.pad(padding_per_expert.cumsum(0), (1, 0)) + sorted_to_padded = torch.arange(num_tokens, device=device) + cumulative_padding[expert_ids_sorted] + + if use_psum_layout: # Blackwell (SM100+) + # psum layout: cumsum of *aligned* per-expert counts — sentinels sit at positions >= + # `grouped_layout[-1]` (by construction of `cumulative_padding`), so the scheduler + # stops before them. The kernel's `num_m_blocks = ceil_div(layout[i] - align(layout[i-1], 128), BLOCK_M)` + # between experts only matches the padded tensor when the stored cumsum is over aligned counts. + grouped_layout = aligned_tokens_per_expert.cumsum(0).int() + else: + # Hopper: per-row expert id, -1 for padding rows and for sentinel slots (kernel skips -1). + grouped_layout = torch.full((total_padded_rows,), -1, device=device, dtype=torch.int32) + grouped_layout[sorted_to_padded] = torch.where(expert_ids_sorted < num_experts, expert_ids_sorted.int(), -1) + + return sorted_to_padded, grouped_layout, total_padded_rows + + +def _pad_for_deepgemm(x: torch.Tensor, sorted_to_padded: torch.Tensor, total_padded_rows: int) -> torch.Tensor: + """Pad a sorted tensor into the TMA-aligned contiguous layout. + + Padding rows are left uninitialized — the kernel skips them via `grouped_layout=-1` (Hopper) + or via the psum offsets (Blackwell), so their values never enter the computation. + """ + padded = torch.empty(total_padded_rows, *x.shape[1:], device=x.device, dtype=x.dtype) + padded[sorted_to_padded] = x + return padded + + +def _unpad_from_deepgemm_contiguous_layout(x_padded: torch.Tensor, sorted_to_padded: torch.Tensor) -> torch.Tensor: + """Remove padding rows from the TMA-aligned contiguous layout.""" + return x_padded[sorted_to_padded] + + def _cdiv(a: int, b: int) -> int: """Ceiling division.""" return (a + b - 1) // b @@ -355,6 +510,98 @@ def fp8_grouped_mm_experts_forward( return final_hidden_states.to(hidden_states.dtype) +def fp8_deepgemm_experts_forward( + self: torch.nn.Module, + hidden_states: torch.Tensor, + top_k_index: torch.Tensor, + top_k_weights: torch.Tensor, +) -> torch.Tensor: + if self.activation_scheme == "static": + raise NotImplementedError( + "deepgemm experts dispatch does not support activation_scheme='static'. " + "Use the default eager dispatch or switch to activation_scheme='dynamic'." + ) + if self.block_size is None: + raise ValueError( + "deep-gemm requires block-wise quantization (block_size=[128, 128]), " + "but got per-tensor quantization (block_size=None)." 
+ ) + if self.block_size[0] != 128 or self.block_size[1] != 128: + raise ValueError(f"deep-gemm requires block_size=(128, 128), got {self.block_size}") + + _, m_grouped_fp8_gemm_nt_contiguous, per_token_cast_to_fp8 = _load_deepgemm_kernel() + + device = hidden_states.device + num_top_k = top_k_index.size(-1) + num_tokens = hidden_states.size(0) + hidden_dim = hidden_states.size(-1) + + # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) + sample_weights = top_k_weights.reshape(-1) # (S,) + expert_ids = top_k_index.reshape(-1) # (S,) + + # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, + # `_build_deepgemm_contiguous_layout` marks their positions as skipped (-1 on Hopper, beyond the + # cumsum on Blackwell), and deep-gemm skips them — so sentinels cost no real GEMM compute. Their + # routing weights are already zero (RouterParallel masks them at dispatch) so the weighted mul + # contributes nothing. + # Sort by expert for grouped processing + expert_ids_g, perm = torch.sort(expert_ids) + selected_hidden_states_g = hidden_states[perm // num_top_k] + sample_weights_g = sample_weights[perm] # inherits zeros at invalid EP slots + + use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 + sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( + expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT, use_psum_layout=use_psum_layout + ) + + # --- Up projection per expert (deep-gemm grouped contiguous) --- + w_up = self.gate_up_proj if self.has_gate else self.up_proj + ws_up = self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv + act_fp8, act_scales = per_token_cast_to_fp8(selected_hidden_states_g, use_ue8m0=False) + act_fp8 = _pad_for_deepgemm(act_fp8, sorted_to_padded, total_padded_rows) + act_scales = _pad_for_deepgemm(act_scales, sorted_to_padded, total_padded_rows) + proj_out = torch.empty(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16) + m_grouped_fp8_gemm_nt_contiguous( + (act_fp8, act_scales), (w_up, ws_up.float()), proj_out, grouped_layout, use_psum_layout=use_psum_layout + ) + + # Apply gating or activation + if self.has_gate: + proj_out = self._apply_gate(proj_out) + else: + proj_out = self.act_fn(proj_out) + + # --- Down projection per expert (deep-gemm grouped contiguous) --- + proj_fp8, proj_scales = per_token_cast_to_fp8(proj_out, use_ue8m0=False) + # Zero-init: unpad later reads sentinel-row positions the kernel never writes. 
+ proj_out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) + m_grouped_fp8_gemm_nt_contiguous( + (proj_fp8, proj_scales), + (self.down_proj, self.down_proj_scale_inv.float()), + proj_out, + grouped_layout, + use_psum_layout=use_psum_layout, + ) + + # Remove padding rows + proj_out = _unpad_from_deepgemm_contiguous_layout(proj_out, sorted_to_padded) + + # Apply routing weights + weighted_out = proj_out * sample_weights_g.unsqueeze(-1) # (S, hidden_dim) + + # Restore original order + inv_perm = torch.empty_like(perm) + inv_perm[perm] = torch.arange(perm.size(0), device=device) + weighted_out = weighted_out[inv_perm] + + # Accumulate results using deterministic reshape+sum instead of index_add_ + # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd) + final_hidden_states = weighted_out.view(num_tokens, num_top_k, hidden_dim).sum(dim=1) + + return final_hidden_states.to(hidden_states.dtype) + + class FP8Experts(nn.Module): def __init__( self, diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index 4f1f9c315959..1ceb9e167409 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -24,7 +24,6 @@ is_torch_less_or_equal, is_torchdynamo_compiling, ) -from .deepgemm import deepgemm_experts_forward from .sonicmoe import sonicmoe_experts_forward @@ -459,7 +458,6 @@ class ExpertsInterface(GeneralInterface): _global_mapping = { "batched_mm": batched_mm_experts_forward, "grouped_mm": grouped_mm_experts_forward, - "deepgemm": deepgemm_experts_forward, "sonicmoe": sonicmoe_experts_forward, } From 996d67d0ce9fa46d46a82c4d552215305ee960cd Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 11:26:27 +0200 Subject: [PATCH 08/22] fix --- src/transformers/integrations/finegrained_fp8.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index b6b437761eed..c75d66087cf7 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -36,6 +36,13 @@ _FP8_MAX = torch.finfo(_FP8_DTYPE).max +def _first_attr(obj, *names): + for name in names: + if hasattr(obj, name): + return getattr(obj, name) + raise AttributeError(f"{type(obj).__name__} has none of: {names}") + + @functools.cache def _load_triton_kernel(): """ From d033a8309a538bb476298d931ec70032792000dd Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 11:38:43 +0200 Subject: [PATCH 09/22] prefix --- .../integrations/finegrained_fp8.py | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index c75d66087cf7..6a4ae50c8f6a 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -108,7 +108,7 @@ def _load_deepgemm_kernel(): required symbols are not found. Returns: - Tuple of (fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous, per_token_cast_to_fp8) + Tuple of (deepgemm_fp8_matmul, deepgemm_grouped_fp8_matmul, deepgemm_per_token_cast_to_fp8) from the deep-gemm kernel. """ if not is_kernels_available(): @@ -142,16 +142,16 @@ def _load_deepgemm_kernel(): "has a build matching the current torch/CUDA." 
) - fp8_gemm_nt = getattr(kernel, "fp8_gemm_nt", None) - m_grouped_fp8_gemm_nt_contiguous = getattr(kernel, "m_grouped_fp8_gemm_nt_contiguous", None) - per_token_cast_to_fp8 = resolve_internal_import(kernel, chained_path="utils.per_token_cast_to_fp8") + deepgemm_fp8_matmul = getattr(kernel, "fp8_gemm_nt", None) + deepgemm_grouped_fp8_matmul = getattr(kernel, "m_grouped_fp8_gemm_nt_contiguous", None) + deepgemm_per_token_cast_to_fp8 = resolve_internal_import(kernel, chained_path="utils.per_token_cast_to_fp8") missing = [ name for name, attr in [ - ("fp8_gemm_nt", fp8_gemm_nt), - ("m_grouped_fp8_gemm_nt_contiguous", m_grouped_fp8_gemm_nt_contiguous), - ("utils.per_token_cast_to_fp8", per_token_cast_to_fp8), + ("fp8_gemm_nt", deepgemm_fp8_matmul), + ("m_grouped_fp8_gemm_nt_contiguous", deepgemm_grouped_fp8_matmul), + ("utils.per_token_cast_to_fp8", deepgemm_per_token_cast_to_fp8), ] if attr is None ] @@ -161,7 +161,7 @@ def _load_deepgemm_kernel(): "Please update the `kernels` package (`pip install -U kernels`)." ) - return fp8_gemm_nt, m_grouped_fp8_gemm_nt_contiguous, per_token_cast_to_fp8 + return deepgemm_fp8_matmul, deepgemm_grouped_fp8_matmul, deepgemm_per_token_cast_to_fp8 def fp8_deepgemm_matmul( @@ -181,11 +181,11 @@ def fp8_deepgemm_matmul( Bs: (N//128, K//128) float32 — per-block weight scales output_dtype: desired output dtype. """ - fp8_gemm_nt, _, _ = _load_deepgemm_kernel() + deepgemm_fp8_matmul, _, _ = _load_deepgemm_kernel() A_2d = A.view(-1, A.shape[-1]) As_2d = As.view(-1, As.shape[-1]) output = torch.empty(A_2d.shape[0], B.shape[0], device=A.device, dtype=output_dtype) - fp8_gemm_nt((A_2d, As_2d.float()), (B, Bs.float()), output) + deepgemm_fp8_matmul((A_2d, As_2d.float()), (B, Bs.float()), output) return output.view(A.shape[:-1] + (B.shape[0],)) @@ -536,7 +536,7 @@ def fp8_deepgemm_experts_forward( if self.block_size[0] != 128 or self.block_size[1] != 128: raise ValueError(f"deep-gemm requires block_size=(128, 128), got {self.block_size}") - _, m_grouped_fp8_gemm_nt_contiguous, per_token_cast_to_fp8 = _load_deepgemm_kernel() + _, deepgemm_grouped_fp8_matmul, deepgemm_per_token_cast_to_fp8 = _load_deepgemm_kernel() device = hidden_states.device num_top_k = top_k_index.size(-1) @@ -565,11 +565,11 @@ def fp8_deepgemm_experts_forward( # --- Up projection per expert (deep-gemm grouped contiguous) --- w_up = self.gate_up_proj if self.has_gate else self.up_proj ws_up = self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv - act_fp8, act_scales = per_token_cast_to_fp8(selected_hidden_states_g, use_ue8m0=False) + act_fp8, act_scales = deepgemm_per_token_cast_to_fp8(selected_hidden_states_g, use_ue8m0=False) act_fp8 = _pad_for_deepgemm(act_fp8, sorted_to_padded, total_padded_rows) act_scales = _pad_for_deepgemm(act_scales, sorted_to_padded, total_padded_rows) proj_out = torch.empty(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16) - m_grouped_fp8_gemm_nt_contiguous( + deepgemm_grouped_fp8_matmul( (act_fp8, act_scales), (w_up, ws_up.float()), proj_out, grouped_layout, use_psum_layout=use_psum_layout ) @@ -580,10 +580,10 @@ def fp8_deepgemm_experts_forward( proj_out = self.act_fn(proj_out) # --- Down projection per expert (deep-gemm grouped contiguous) --- - proj_fp8, proj_scales = per_token_cast_to_fp8(proj_out, use_ue8m0=False) + proj_fp8, proj_scales = deepgemm_per_token_cast_to_fp8(proj_out, use_ue8m0=False) # Zero-init: unpad later reads sentinel-row positions the kernel never writes. 
proj_out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) - m_grouped_fp8_gemm_nt_contiguous( + deepgemm_grouped_fp8_matmul( (proj_fp8, proj_scales), (self.down_proj, self.down_proj_scale_inv.float()), proj_out, From e15cfe6ad62f13e87cfe07353c787f0ba7fcb3d0 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 12:15:40 +0200 Subject: [PATCH 10/22] move --- .../integrations/finegrained_fp8.py | 201 +++++++++--------- 1 file changed, 100 insertions(+), 101 deletions(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index 6a4ae50c8f6a..f268018314ac 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -101,7 +101,7 @@ def _load_triton_kernel(): @functools.cache def _load_deepgemm_kernel(): """ - Load deep-gemm once and return its required symbols. + Load DeepGEMM once and return its required symbols. Raises: ImportError if CUDA/hardware requirements are not met, or the kernel or @@ -109,36 +109,36 @@ def _load_deepgemm_kernel(): Returns: Tuple of (deepgemm_fp8_matmul, deepgemm_grouped_fp8_matmul, deepgemm_per_token_cast_to_fp8) - from the deep-gemm kernel. + from the DeepGEMM kernel. """ if not is_kernels_available(): - raise ImportError("deep-gemm kernel requires the `kernels` package. Install it with `pip install -U kernels`.") + raise ImportError("DeepGEMM kernel requires the `kernels` package. Install it with `pip install -U kernels`.") if not torch.cuda.is_available(): raise ImportError( - "deep-gemm kernel requires CUDA, but CUDA is not available. Use a different `experts_implementation`." + "DeepGEMM kernel requires CUDA, but CUDA is not available. Use a different `experts_implementation`." ) - # deep-gemm requires Hopper (SM90) or newer for FP8 WGMMA instructions + # DeepGEMM requires Hopper (SM90) or newer for FP8 WGMMA instructions major = torch.cuda.get_device_capability()[0] if major < 9: raise ImportError( - f"deep-gemm requires a Hopper (SM90+) or newer GPU, but the current device " + f"DeepGEMM requires a Hopper (SM90+) or newer GPU, but the current device " f"has compute capability {major}.x. Use a different `experts_implementation`." ) - # deep-gemm requires CUDA runtime >= 12.3 + # DeepGEMM requires CUDA runtime >= 12.3 cuda_major, cuda_minor = get_cuda_runtime_version() if cuda_major < 12 or (cuda_major == 12 and cuda_minor < 3): raise ImportError( - f"deep-gemm requires CUDA runtime 12.3+, but found {cuda_major}.{cuda_minor}. " + f"DeepGEMM requires CUDA runtime 12.3+, but found {cuda_major}.{cuda_minor}. " "Please upgrade your CUDA toolkit or use a different `experts_implementation`." ) - kernel = lazy_load_kernel("deep-gemm") + kernel = lazy_load_kernel("DeepGEMM") if kernel is None: raise ImportError( - "Failed to load the deep-gemm kernel — check that `kernels-community/deep-gemm` " + "Failed to load the DeepGEMM kernel — check that `kernels-community/deep-gemm` " "has a build matching the current torch/CUDA." ) @@ -157,97 +157,13 @@ def _load_deepgemm_kernel(): ] if missing: raise ImportError( - f"deep-gemm kernel is missing required symbols: {', '.join(missing)}. " + f"DeepGEMM kernel is missing required symbols: {', '.join(missing)}. " "Please update the `kernels` package (`pip install -U kernels`)." 
) return deepgemm_fp8_matmul, deepgemm_grouped_fp8_matmul, deepgemm_per_token_cast_to_fp8 -def fp8_deepgemm_matmul( - A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor, - output_dtype: torch.dtype = torch.float32, -) -> torch.Tensor: - """ - FP8 dense matmul via deep-gemm's `fp8_gemm_nt`. Block-wise 128x128 scales expected. - - Args: - A: (M, K) float8_e4m3fn — quantized activations - B: (N, K) float8_e4m3fn — quantized weights - As: (M, K//128) float32 — per-block activation scales - Bs: (N//128, K//128) float32 — per-block weight scales - output_dtype: desired output dtype. - """ - deepgemm_fp8_matmul, _, _ = _load_deepgemm_kernel() - A_2d = A.view(-1, A.shape[-1]) - As_2d = As.view(-1, As.shape[-1]) - output = torch.empty(A_2d.shape[0], B.shape[0], device=A.device, dtype=output_dtype) - deepgemm_fp8_matmul((A_2d, As_2d.float()), (B, Bs.float()), output) - return output.view(A.shape[:-1] + (B.shape[0],)) - - -def _build_deepgemm_contiguous_layout( - expert_ids_sorted: torch.Tensor, num_experts: int, alignment: int, use_psum_layout: bool -) -> tuple: - """Build the TMA-aligned layout deep-gemm's grouped GEMM expects. - - Returns `(sorted_to_padded, grouped_layout, total_padded_rows)`. `grouped_layout` encodes - expert boundaries as a cumsum of aligned counts on Blackwell (`use_psum_layout=True`) or - per-row expert ids with -1 for padding on Hopper. - - Accepts EP sentinels: values in `expert_ids_sorted` equal to `num_experts` (unclamped sentinels) - are routed past the last aligned expert block and marked `-1` in the Hopper layout (and - excluded from the Blackwell cumsum), so deep-gemm skips them. - """ - device = expert_ids_sorted.device - num_tokens = expert_ids_sorted.size(0) - # histc drops values > max, so EP sentinels (== num_experts) are excluded from the per-expert count. - tokens_per_expert = torch.histc(expert_ids_sorted.int(), bins=num_experts, min=0, max=num_experts - 1).long() - aligned_tokens_per_expert = ((tokens_per_expert + alignment - 1) // alignment) * alignment - # Upper bound avoids GPU->CPU sync; padding rows are skipped by deep-gemm. - total_padded_rows = num_tokens + min(num_tokens, num_experts) * (alignment - 1) - - # Zero-prepended inclusive cumsum of per-expert padding. Indices [0, num_experts) give the - # exclusive cumsum (padding before expert i) and index `num_experts` gives `sum(padding)`, - # which routes EP sentinels past all valid aligned expert blocks on Blackwell (where the - # kernel stops at `aligned_cumsum[-1]`) — so sentinels don't go through the GEMM. - padding_per_expert = aligned_tokens_per_expert - tokens_per_expert - cumulative_padding = torch.nn.functional.pad(padding_per_expert.cumsum(0), (1, 0)) - sorted_to_padded = torch.arange(num_tokens, device=device) + cumulative_padding[expert_ids_sorted] - - if use_psum_layout: # Blackwell (SM100+) - # psum layout: cumsum of *aligned* per-expert counts — sentinels sit at positions >= - # `grouped_layout[-1]` (by construction of `cumulative_padding`), so the scheduler - # stops before them. The kernel's `num_m_blocks = ceil_div(layout[i] - align(layout[i-1], 128), BLOCK_M)` - # between experts only matches the padded tensor when the stored cumsum is over aligned counts. - grouped_layout = aligned_tokens_per_expert.cumsum(0).int() - else: - # Hopper: per-row expert id, -1 for padding rows and for sentinel slots (kernel skips -1). 
- grouped_layout = torch.full((total_padded_rows,), -1, device=device, dtype=torch.int32) - grouped_layout[sorted_to_padded] = torch.where(expert_ids_sorted < num_experts, expert_ids_sorted.int(), -1) - - return sorted_to_padded, grouped_layout, total_padded_rows - - -def _pad_for_deepgemm(x: torch.Tensor, sorted_to_padded: torch.Tensor, total_padded_rows: int) -> torch.Tensor: - """Pad a sorted tensor into the TMA-aligned contiguous layout. - - Padding rows are left uninitialized — the kernel skips them via `grouped_layout=-1` (Hopper) - or via the psum offsets (Blackwell), so their values never enter the computation. - """ - padded = torch.empty(total_padded_rows, *x.shape[1:], device=x.device, dtype=x.dtype) - padded[sorted_to_padded] = x - return padded - - -def _unpad_from_deepgemm_contiguous_layout(x_padded: torch.Tensor, sorted_to_padded: torch.Tensor) -> torch.Tensor: - """Remove padding rows from the TMA-aligned contiguous layout.""" - return x_padded[sorted_to_padded] - - def _cdiv(a: int, b: int) -> int: """Ceiling division.""" return (a + b - 1) // b @@ -425,7 +341,6 @@ def fp8_batched_mm_experts_forward( ) # (S, hidden_dim) # Apply routing weights - # Let torch promote bf16 `proj_out` × fp32 `sample_weights` to fp32 for the reduction below. weighted_out = proj_out * sample_weights.unsqueeze(-1) # (S, hidden_dim) # Accumulate results using deterministic reshape+sum instead of index_add_ @@ -517,6 +432,90 @@ def fp8_grouped_mm_experts_forward( return final_hidden_states.to(hidden_states.dtype) +def fp8_deepgemm_matmul( + A: torch.Tensor, + B: torch.Tensor, + As: torch.Tensor, + Bs: torch.Tensor, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + """ + FP8 dense matmul via DeepGEMM's `fp8_gemm_nt`. Block-wise 128x128 scales expected. + + Args: + A: (M, K) float8_e4m3fn — quantized activations + B: (N, K) float8_e4m3fn — quantized weights + As: (M, K//128) float32 — per-block activation scales + Bs: (N//128, K//128) float32 — per-block weight scales + output_dtype: desired output dtype. + """ + deepgemm_fp8_matmul, _, _ = _load_deepgemm_kernel() + A_2d = A.view(-1, A.shape[-1]) + As_2d = As.view(-1, As.shape[-1]) + output = torch.empty(A_2d.shape[0], B.shape[0], device=A.device, dtype=output_dtype) + deepgemm_fp8_matmul((A_2d, As_2d.float()), (B, Bs.float()), output) + return output.view(A.shape[:-1] + (B.shape[0],)) + + +def _build_deepgemm_contiguous_layout( + expert_ids_sorted: torch.Tensor, num_experts: int, alignment: int, use_psum_layout: bool +) -> tuple: + """Build the TMA-aligned layout DeepGEMM's grouped GEMM expects. + + Returns `(sorted_to_padded, grouped_layout, total_padded_rows)`. `grouped_layout` encodes + expert boundaries as a cumsum of aligned counts on Blackwell (`use_psum_layout=True`) or + per-row expert ids with -1 for padding on Hopper. + + Accepts EP sentinels: values in `expert_ids_sorted` equal to `num_experts` (unclamped sentinels) + are routed past the last aligned expert block and marked `-1` in the Hopper layout (and + excluded from the Blackwell cumsum), so DeepGEMM skips them. + """ + device = expert_ids_sorted.device + num_tokens = expert_ids_sorted.size(0) + # histc drops values > max, so EP sentinels (== num_experts) are excluded from the per-expert count. 
+ tokens_per_expert = torch.histc(expert_ids_sorted.int(), bins=num_experts, min=0, max=num_experts - 1).long() + aligned_tokens_per_expert = ((tokens_per_expert + alignment - 1) // alignment) * alignment + # Upper bound avoids GPU->CPU sync; padding rows are skipped by DeepGEMM. + total_padded_rows = num_tokens + min(num_tokens, num_experts) * (alignment - 1) + + # Zero-prepended inclusive cumsum of per-expert padding. Indices [0, num_experts) give the + # exclusive cumsum (padding before expert i) and index `num_experts` gives `sum(padding)`, + # which routes EP sentinels past all valid aligned expert blocks on Blackwell (where the + # kernel stops at `aligned_cumsum[-1]`) — so sentinels don't go through the GEMM. + padding_per_expert = aligned_tokens_per_expert - tokens_per_expert + cumulative_padding = torch.nn.functional.pad(padding_per_expert.cumsum(0), (1, 0)) + sorted_to_padded = torch.arange(num_tokens, device=device) + cumulative_padding[expert_ids_sorted] + + if use_psum_layout: # Blackwell (SM100+) + # psum layout: cumsum of *aligned* per-expert counts — sentinels sit at positions >= + # `grouped_layout[-1]` (by construction of `cumulative_padding`), so the scheduler + # stops before them. The kernel's `num_m_blocks = ceil_div(layout[i] - align(layout[i-1], 128), BLOCK_M)` + # between experts only matches the padded tensor when the stored cumsum is over aligned counts. + grouped_layout = aligned_tokens_per_expert.cumsum(0).int() + else: + # Hopper: per-row expert id, -1 for padding rows and for sentinel slots (kernel skips -1). + grouped_layout = torch.full((total_padded_rows,), -1, device=device, dtype=torch.int32) + grouped_layout[sorted_to_padded] = torch.where(expert_ids_sorted < num_experts, expert_ids_sorted.int(), -1) + + return sorted_to_padded, grouped_layout, total_padded_rows + + +def _pad_for_deepgemm(x: torch.Tensor, sorted_to_padded: torch.Tensor, total_padded_rows: int) -> torch.Tensor: + """Pad a sorted tensor into the TMA-aligned contiguous layout. + + Padding rows are left uninitialized — the kernel skips them via `grouped_layout=-1` (Hopper) + or via the psum offsets (Blackwell), so their values never enter the computation. + """ + padded = torch.empty(total_padded_rows, *x.shape[1:], device=x.device, dtype=x.dtype) + padded[sorted_to_padded] = x + return padded + + +def _unpad_from_deepgemm_contiguous_layout(x_padded: torch.Tensor, sorted_to_padded: torch.Tensor) -> torch.Tensor: + """Remove padding rows from the TMA-aligned contiguous layout.""" + return x_padded[sorted_to_padded] + + def fp8_deepgemm_experts_forward( self: torch.nn.Module, hidden_states: torch.Tensor, @@ -530,11 +529,11 @@ def fp8_deepgemm_experts_forward( ) if self.block_size is None: raise ValueError( - "deep-gemm requires block-wise quantization (block_size=[128, 128]), " + "DeepGEMM requires block-wise quantization (block_size=[128, 128]), " "but got per-tensor quantization (block_size=None)." 
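A toy walk-through of the layout arithmetic above, using alignment=4 instead of 128 and made-up counts (2 experts receiving 3 and 5 routed rows); it exercises only the index math, no kernel call.

    import torch

    expert_ids_sorted = torch.tensor([0, 0, 0, 1, 1, 1, 1, 1])
    alignment = 4
    tokens_per_expert = torch.tensor([3, 5])
    aligned_tokens_per_expert = ((tokens_per_expert + alignment - 1) // alignment) * alignment  # [4, 8]
    padding_per_expert = aligned_tokens_per_expert - tokens_per_expert                          # [1, 3]
    cumulative_padding = torch.nn.functional.pad(padding_per_expert.cumsum(0), (1, 0))          # [0, 1, 4]
    sorted_to_padded = torch.arange(8) + cumulative_padding[expert_ids_sorted]
    print(sorted_to_padded.tolist())  # [0, 1, 2, 4, 5, 6, 7, 8]: expert 1 starts at padded row 4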
) if self.block_size[0] != 128 or self.block_size[1] != 128: - raise ValueError(f"deep-gemm requires block_size=(128, 128), got {self.block_size}") + raise ValueError(f"DeepGEMM requires block_size=(128, 128), got {self.block_size}") _, deepgemm_grouped_fp8_matmul, deepgemm_per_token_cast_to_fp8 = _load_deepgemm_kernel() @@ -549,7 +548,7 @@ def fp8_deepgemm_experts_forward( # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, # `_build_deepgemm_contiguous_layout` marks their positions as skipped (-1 on Hopper, beyond the - # cumsum on Blackwell), and deep-gemm skips them — so sentinels cost no real GEMM compute. Their + # cumsum on Blackwell), and DeepGEMM skips them — so sentinels cost no real GEMM compute. Their # routing weights are already zero (RouterParallel masks them at dispatch) so the weighted mul # contributes nothing. # Sort by expert for grouped processing @@ -562,7 +561,7 @@ def fp8_deepgemm_experts_forward( expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT, use_psum_layout=use_psum_layout ) - # --- Up projection per expert (deep-gemm grouped contiguous) --- + # --- Up projection per expert (DeepGEMM grouped contiguous) --- w_up = self.gate_up_proj if self.has_gate else self.up_proj ws_up = self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv act_fp8, act_scales = deepgemm_per_token_cast_to_fp8(selected_hidden_states_g, use_ue8m0=False) @@ -579,7 +578,7 @@ def fp8_deepgemm_experts_forward( else: proj_out = self.act_fn(proj_out) - # --- Down projection per expert (deep-gemm grouped contiguous) --- + # --- Down projection per expert (DeepGEMM grouped contiguous) --- proj_fp8, proj_scales = deepgemm_per_token_cast_to_fp8(proj_out, use_ue8m0=False) # Zero-init: unpad later reads sentinel-row positions the kernel never writes. proj_out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) From 10b6d904105bef0850afd69a9f177e0ff9d22389 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 12:17:59 +0200 Subject: [PATCH 11/22] fix --- src/transformers/integrations/finegrained_fp8.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index f268018314ac..bd20894c382c 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -135,7 +135,7 @@ def _load_deepgemm_kernel(): "Please upgrade your CUDA toolkit or use a different `experts_implementation`." ) - kernel = lazy_load_kernel("DeepGEMM") + kernel = lazy_load_kernel("deep-gemm") if kernel is None: raise ImportError( "Failed to load the DeepGEMM kernel — check that `kernels-community/deep-gemm` " @@ -524,7 +524,7 @@ def fp8_deepgemm_experts_forward( ) -> torch.Tensor: if self.activation_scheme == "static": raise NotImplementedError( - "deepgemm experts dispatch does not support activation_scheme='static'. " + "DeepGEMM experts dispatch does not support activation_scheme='static'. " "Use the default eager dispatch or switch to activation_scheme='dynamic'." 
) if self.block_size is None: From d4a6b3056f701dc0c307b02b73a04a890b0bfc30 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 12:30:20 +0200 Subject: [PATCH 12/22] remove comment --- src/transformers/integrations/finegrained_fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index bd20894c382c..910eab7838c1 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -554,7 +554,7 @@ def fp8_deepgemm_experts_forward( # Sort by expert for grouped processing expert_ids_g, perm = torch.sort(expert_ids) selected_hidden_states_g = hidden_states[perm // num_top_k] - sample_weights_g = sample_weights[perm] # inherits zeros at invalid EP slots + sample_weights_g = sample_weights[perm] use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( From 1d6054ff5904407bda8e47bbddd95971f85582e0 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 13:00:12 +0200 Subject: [PATCH 13/22] fix unintilized outputs leaking --- .../integrations/finegrained_fp8.py | 24 +++++++++++++------ src/transformers/integrations/moe.py | 10 ++++++-- 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index 910eab7838c1..e5a4479f178e 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -375,8 +375,9 @@ def fp8_grouped_mm_experts_forward( # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, # `histc(max=num_experts-1)` drops them from `tokens_per_expert`, and the grouped matmul skips - # rows beyond `offsets[-1]` — so sentinels cost no real GEMM compute. Their routing weights are - # already zero (RouterParallel masks them at dispatch) so the weighted mul contributes nothing. + # rows beyond `offsets[-1]` — so sentinels cost no real GEMM compute. Sentinel rows are zeroed + # post-weighted-mul (see below), since the kernel leaves them uninitialized. + # Sort by expert for grouped processing expert_ids_g, perm = torch.sort(expert_ids) selected_hidden_states_g = hidden_states[perm // num_top_k] @@ -420,6 +421,11 @@ def fp8_grouped_mm_experts_forward( # Apply routing weights weighted_out = proj_out * sample_weights_g.unsqueeze(-1) # (S, hidden_dim) + # EP sentinel handling: `proj_out` rows past `offsets[-1]` are left uninitialized by the kernel, + # so `proj_out[sentinel] * 0 = 0 * NaN = NaN` can leak from allocator pool reuse. Zero them here + # so the downstream reduction stays finite even when the routing weight was already zero. + weighted_out.masked_fill_((expert_ids_g >= self.num_experts).unsqueeze(-1), 0.0) + # Restore original order inv_perm = torch.empty_like(perm) inv_perm[perm] = torch.arange(perm.size(0), device=device) @@ -548,9 +554,9 @@ def fp8_deepgemm_experts_forward( # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, # `_build_deepgemm_contiguous_layout` marks their positions as skipped (-1 on Hopper, beyond the - # cumsum on Blackwell), and DeepGEMM skips them — so sentinels cost no real GEMM compute. Their - # routing weights are already zero (RouterParallel masks them at dispatch) so the weighted mul - # contributes nothing. 
+ # cumsum on Blackwell), and DeepGEMM skips them — so sentinels cost no real GEMM compute. + # Sentinel rows are zeroed post-weighted-mul (see below), since the kernel leaves them uninitialized. + # Sort by expert for grouped processing expert_ids_g, perm = torch.sort(expert_ids) selected_hidden_states_g = hidden_states[perm // num_top_k] @@ -580,8 +586,7 @@ def fp8_deepgemm_experts_forward( # --- Down projection per expert (DeepGEMM grouped contiguous) --- proj_fp8, proj_scales = deepgemm_per_token_cast_to_fp8(proj_out, use_ue8m0=False) - # Zero-init: unpad later reads sentinel-row positions the kernel never writes. - proj_out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) + proj_out = torch.empty(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) deepgemm_grouped_fp8_matmul( (proj_fp8, proj_scales), (self.down_proj, self.down_proj_scale_inv.float()), @@ -596,6 +601,11 @@ def fp8_deepgemm_experts_forward( # Apply routing weights weighted_out = proj_out * sample_weights_g.unsqueeze(-1) # (S, hidden_dim) + # EP sentinel handling: `proj_out` rows past the valid expert blocks are left uninitialized by the kernel, + # so `proj_out[sentinel] * 0 = 0 * NaN = NaN` can leak from allocator pool reuse. Zero them here + # so the downstream reduction stays finite even when the routing weight was already zero. + weighted_out.masked_fill_((expert_ids_g >= self.num_experts).unsqueeze(-1), 0.0) + # Restore original order inv_perm = torch.empty_like(perm) inv_perm[perm] = torch.arange(perm.size(0), device=device) diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index 1ceb9e167409..4ef11fe029b7 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -381,8 +381,9 @@ def grouped_mm_experts_forward( # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, # `histc(max=num_experts-1)` drops them from `tokens_per_expert`, and grouped_mm skips rows - # beyond `offsets[-1]` — so sentinels cost no real GEMM compute. Their routing weights are - # already zero (RouterParallel masks them at dispatch) so the weighted mul contributes nothing. + # beyond `offsets[-1]` — so sentinels cost no real GEMM compute. Sentinel rows are zeroed + # post-weighted-mul (see below), since the kernel leaves them uninitialized. + # Sort by expert for grouped processing expert_ids_g, perm = torch.sort(expert_ids) selected_hidden_states_g = hidden_states[perm // num_top_k] @@ -438,6 +439,11 @@ def grouped_mm_experts_forward( # Apply routing weights weighted_out = proj_out * sample_weights_g.unsqueeze(-1) # (S, hidden_dim) + # EP sentinel handling: `proj_out` rows past `offsets[-1]` are left uninitialized by grouped_mm, + # so `proj_out[sentinel] * 0 = 0 * NaN = NaN` can leak from allocator pool reuse. Zero them here + # so the downstream reduction stays finite even when the routing weight was already zero. 
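A tiny repro of the failure mode described in that comment, with made-up values: rows 2 and 3 stand in for uninitialized sentinel rows whose routing weight is already zero.

    import torch

    nan = float("nan")
    proj_out = torch.tensor([[1.0, 2.0], [3.0, 4.0], [nan, nan], [nan, nan]])
    weights = torch.tensor([0.5, 0.5, 0.0, 0.0])
    sentinel = torch.tensor([False, False, True, True])

    weighted = proj_out * weights.unsqueeze(-1)          # rows 2-3 stay NaN: 0 * NaN == NaN
    weighted.masked_fill_(sentinel.unsqueeze(-1), 0.0)   # explicit zeroing makes them finite
    print(torch.isfinite(weighted).all().item())         # True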
+ weighted_out.masked_fill_((expert_ids_g >= self.num_experts).unsqueeze(-1), 0.0) + # Restore original order inv_perm = torch.empty_like(perm) inv_perm[perm] = torch.arange(perm.size(0), device=device) From 137393cda9bc902f7f8dce942dd68ed25be28c2a Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 13:07:37 +0200 Subject: [PATCH 14/22] revert unnecessary changes --- .../integrations/finegrained_fp8.py | 36 +++++++++++-------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index e5a4479f178e..684da70f8610 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -419,7 +419,7 @@ def fp8_grouped_mm_experts_forward( ) # (S, hidden_dim) # Apply routing weights - weighted_out = proj_out * sample_weights_g.unsqueeze(-1) # (S, hidden_dim) + weighted_out = proj_out * sample_weights_g.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim) # EP sentinel handling: `proj_out` rows past `offsets[-1]` are left uninitialized by the kernel, # so `proj_out[sentinel] * 0 = 0 * NaN = NaN` can leak from allocator pool reuse. Zero them here @@ -506,20 +506,27 @@ def _build_deepgemm_contiguous_layout( return sorted_to_padded, grouped_layout, total_padded_rows -def _pad_for_deepgemm(x: torch.Tensor, sorted_to_padded: torch.Tensor, total_padded_rows: int) -> torch.Tensor: - """Pad a sorted tensor into the TMA-aligned contiguous layout. - - Padding rows are left uninitialized — the kernel skips them via `grouped_layout=-1` (Hopper) - or via the psum offsets (Blackwell), so their values never enter the computation. - """ - padded = torch.empty(total_padded_rows, *x.shape[1:], device=x.device, dtype=x.dtype) - padded[sorted_to_padded] = x - return padded +def _pad_to_deepgemm_contiguous_layout( + hidden_states: torch.Tensor, + scales: torch.Tensor, + sorted_to_padded: torch.Tensor, + total_padded_rows: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """Pad sorted hidden states and scales into the TMA-aligned contiguous layout.""" + hidden_padded = torch.zeros( + total_padded_rows, hidden_states.shape[1], device=hidden_states.device, dtype=hidden_states.dtype + ) + hidden_padded[sorted_to_padded] = hidden_states + scales_padded = torch.zeros(total_padded_rows, scales.shape[1], device=hidden_states.device, dtype=torch.float32) + scales_padded[sorted_to_padded] = scales + return hidden_padded, scales_padded -def _unpad_from_deepgemm_contiguous_layout(x_padded: torch.Tensor, sorted_to_padded: torch.Tensor) -> torch.Tensor: +def _unpad_from_deepgemm_contiguous_layout( + hidden_states_padded: torch.Tensor, sorted_to_padded: torch.Tensor +) -> torch.Tensor: """Remove padding rows from the TMA-aligned contiguous layout.""" - return x_padded[sorted_to_padded] + return hidden_states_padded[sorted_to_padded] def fp8_deepgemm_experts_forward( @@ -571,8 +578,7 @@ def fp8_deepgemm_experts_forward( w_up = self.gate_up_proj if self.has_gate else self.up_proj ws_up = self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv act_fp8, act_scales = deepgemm_per_token_cast_to_fp8(selected_hidden_states_g, use_ue8m0=False) - act_fp8 = _pad_for_deepgemm(act_fp8, sorted_to_padded, total_padded_rows) - act_scales = _pad_for_deepgemm(act_scales, sorted_to_padded, total_padded_rows) + act_fp8, act_scales = _pad_to_deepgemm_contiguous_layout(act_fp8, act_scales, sorted_to_padded, total_padded_rows) proj_out = 
torch.empty(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16) deepgemm_grouped_fp8_matmul( (act_fp8, act_scales), (w_up, ws_up.float()), proj_out, grouped_layout, use_psum_layout=use_psum_layout @@ -599,7 +605,7 @@ def fp8_deepgemm_experts_forward( proj_out = _unpad_from_deepgemm_contiguous_layout(proj_out, sorted_to_padded) # Apply routing weights - weighted_out = proj_out * sample_weights_g.unsqueeze(-1) # (S, hidden_dim) + weighted_out = proj_out * sample_weights_g.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim) # EP sentinel handling: `proj_out` rows past the valid expert blocks are left uninitialized by the kernel, # so `proj_out[sentinel] * 0 = 0 * NaN = NaN` can leak from allocator pool reuse. Zero them here From 774f90181dc5d7f8cea1e25b8dd46444b4ac524a Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 13:09:03 +0200 Subject: [PATCH 15/22] more unnecessary changes --- src/transformers/integrations/finegrained_fp8.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index 684da70f8610..a07a0cdd37e2 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -36,6 +36,12 @@ _FP8_MAX = torch.finfo(_FP8_DTYPE).max +# DeepGEMM requires M-dimension alignment to 128 for TMA-based contiguous grouped GEMM. +# TMA is an H100 hardware addition that allows applications to asynchronously and +# bi-directionally transfer 1D-5D tensors between GPU global and shared memory. +_DEEPGEMM_M_ALIGNMENT = 128 + + def _first_attr(obj, *names): for name in names: if hasattr(obj, name): @@ -92,12 +98,6 @@ def _load_triton_kernel(): return triton_fp8_matmul, triton_fp8_act_quant, triton_batched_fp8_matmul, triton_grouped_fp8_matmul -# DeepGEMM requires M-dimension alignment to 128 for TMA-based contiguous grouped GEMM. -# TMA is an H100 hardware addition that allows applications to asynchronously and -# bi-directionally transfer 1D-5D tensors between GPU global and shared memory. 
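For context on the `.to(proj_out.dtype)` cast above, a small sketch with made-up shapes: multiplying bf16 activations by fp32 routing weights otherwise promotes the whole product to fp32, while casting the much smaller weight vector keeps the result in bf16.

    import torch

    proj_out = torch.ones(4, 8, dtype=torch.bfloat16)
    sample_weights = torch.rand(4, dtype=torch.float32)
    print((proj_out * sample_weights.unsqueeze(-1)).dtype)                     # torch.float32
    print((proj_out * sample_weights.to(proj_out.dtype).unsqueeze(-1)).dtype)  # torch.bfloat16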
-_DEEPGEMM_M_ALIGNMENT = 128 - - @functools.cache def _load_deepgemm_kernel(): """ From 81230feeaf3c9234399755186394019bd5a21ee4 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 13:30:07 +0200 Subject: [PATCH 16/22] revert downcast --- src/transformers/integrations/finegrained_fp8.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index a07a0cdd37e2..64e9c3722c28 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -341,7 +341,7 @@ def fp8_batched_mm_experts_forward( ) # (S, hidden_dim) # Apply routing weights - weighted_out = proj_out * sample_weights.unsqueeze(-1) # (S, hidden_dim) + weighted_out = proj_out * sample_weights.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim) # Accumulate results using deterministic reshape+sum instead of index_add_ # (index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd) From 9f2ff08915bf865791ac8ef2ddbb79ccac317b5b Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 14:52:26 +0200 Subject: [PATCH 17/22] keep it simple --- .../integrations/finegrained_fp8.py | 44 +++++++------------ 1 file changed, 16 insertions(+), 28 deletions(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index 64e9c3722c28..61190b480be3 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -128,7 +128,14 @@ def _load_deepgemm_kernel(): ) # DeepGEMM requires CUDA runtime >= 12.3 - cuda_major, cuda_minor = get_cuda_runtime_version() + try: + cuda_major, cuda_minor = get_cuda_runtime_version() + except OSError as e: + raise ImportError( + f"DeepGEMM requires CUDA runtime 12.3+, but libcudart could not be loaded ({e}). " + "Use a different `experts_implementation`." + ) from e + if cuda_major < 12 or (cuda_major == 12 and cuda_minor < 3): raise ImportError( f"DeepGEMM requires CUDA runtime 12.3+, but found {cuda_major}.{cuda_minor}. " @@ -197,14 +204,20 @@ def w8a8_fp8_matmul( """ if block_size is not None and block_size[0] == block_size[1] == 128: try: - # 3-6x faster than Triton - return fp8_deepgemm_matmul(A, B, As, Bs, output_dtype=output_dtype) + deepgemm_fp8_matmul, _, _ = _load_deepgemm_kernel() except ImportError: logger.warning_once( "DeepGEMM kernel is not available or compatible, falling back to Triton finegrained-fp8 kernel. " "To use DeepGEMM FP8 matmul, ensure you have a Hopper (SM90+) or newer GPU with CUDA runtime 12.3+, " "and that the `kernels` package is installed and up to date (`pip install -U kernels`)." ) + else: + # 3-6x faster than Triton + A_2d = A.view(-1, A.shape[-1]) + As_2d = As.view(-1, As.shape[-1]) + output = torch.empty(A_2d.shape[0], B.shape[0], device=A.device, dtype=output_dtype) + deepgemm_fp8_matmul((A_2d, As_2d.float()), (B, Bs.float()), output) + return output.view(A.shape[:-1] + (B.shape[0],)) triton_fp8_matmul, _, _, _ = _load_triton_kernel() return triton_fp8_matmul(A, B, As, Bs, block_size, output_dtype) @@ -438,31 +451,6 @@ def fp8_grouped_mm_experts_forward( return final_hidden_states.to(hidden_states.dtype) -def fp8_deepgemm_matmul( - A: torch.Tensor, - B: torch.Tensor, - As: torch.Tensor, - Bs: torch.Tensor, - output_dtype: torch.dtype = torch.float32, -) -> torch.Tensor: - """ - FP8 dense matmul via DeepGEMM's `fp8_gemm_nt`. 
Block-wise 128x128 scales expected. - - Args: - A: (M, K) float8_e4m3fn — quantized activations - B: (N, K) float8_e4m3fn — quantized weights - As: (M, K//128) float32 — per-block activation scales - Bs: (N//128, K//128) float32 — per-block weight scales - output_dtype: desired output dtype. - """ - deepgemm_fp8_matmul, _, _ = _load_deepgemm_kernel() - A_2d = A.view(-1, A.shape[-1]) - As_2d = As.view(-1, As.shape[-1]) - output = torch.empty(A_2d.shape[0], B.shape[0], device=A.device, dtype=output_dtype) - deepgemm_fp8_matmul((A_2d, As_2d.float()), (B, Bs.float()), output) - return output.view(A.shape[:-1] + (B.shape[0],)) - - def _build_deepgemm_contiguous_layout( expert_ids_sorted: torch.Tensor, num_experts: int, alignment: int, use_psum_layout: bool ) -> tuple: From c55b7b7863e864224390ea79f412a2a1830dfab5 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 14:59:15 +0200 Subject: [PATCH 18/22] guard deepgemm cuda version --- .../integrations/finegrained_fp8.py | 9 +------- src/transformers/utils/import_utils.py | 21 +++++++++++++++---- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index 61190b480be3..f423f2f6b830 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -128,14 +128,7 @@ def _load_deepgemm_kernel(): ) # DeepGEMM requires CUDA runtime >= 12.3 - try: - cuda_major, cuda_minor = get_cuda_runtime_version() - except OSError as e: - raise ImportError( - f"DeepGEMM requires CUDA runtime 12.3+, but libcudart could not be loaded ({e}). " - "Use a different `experts_implementation`." - ) from e - + cuda_major, cuda_minor = get_cuda_runtime_version() if cuda_major < 12 or (cuda_major == 12 and cuda_minor < 3): raise ImportError( f"DeepGEMM requires CUDA runtime 12.3+, but found {cuda_major}.{cuda_minor}. " diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index de11d23cbecf..756363ea6c52 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -222,14 +222,27 @@ def is_cuda_platform() -> bool: def get_cuda_runtime_version() -> tuple[int, int]: """Return the CUDA runtime version as (major, minor). - Unlike ``torch.version.cuda`` which reports the compile-time version, - this queries ``cudaRuntimeGetVersion`` from ``libcudart.so`` to get the - actual runtime version installed on the system. + Prefers a direct query of ``cudaRuntimeGetVersion`` via ``libcudart.so``. If that's + not on the system loader path (common with pip-installed torch that bundles its own + CUDA runtime), falls back to ``torch.version.cuda`` — which equals the bundled + runtime's version for pip wheels. Returns ``(0, 0)`` for CPU-only torch. 
""" import ctypes + try: + cudart = ctypes.CDLL("libcudart.so") + except OSError: + if not is_torch_available(): + return 0, 0 + import torch + + if getattr(torch.version, "cuda", None) is None: + return 0, 0 + + major, minor, *_ = torch.version.cuda.split(".") + return int(major), int(minor) + version = ctypes.c_int() - cudart = ctypes.CDLL("libcudart.so") cudart.cudaRuntimeGetVersion(ctypes.byref(version)) return version.value // 1000, (version.value % 1000) // 10 From 20858db8159171d2ca430766b21baf1a49493bd9 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Fri, 24 Apr 2026 15:19:19 +0200 Subject: [PATCH 19/22] fix style --- src/transformers/utils/import_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index 756363ea6c52..8654bd083ba2 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -236,10 +236,11 @@ def get_cuda_runtime_version() -> tuple[int, int]: return 0, 0 import torch - if getattr(torch.version, "cuda", None) is None: + cuda_version = getattr(torch.version, "cuda", None) + if cuda_version is None: return 0, 0 - major, minor, *_ = torch.version.cuda.split(".") + major, minor, *_ = cuda_version.split(".") return int(major), int(minor) version = ctypes.c_int() From 89d2f0bb3eeb38c1e6431013f6c67b5cbf30a388 Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Sun, 26 Apr 2026 12:43:08 +0200 Subject: [PATCH 20/22] moe sentinel support --- src/transformers/integrations/hub_kernels.py | 6 ++++-- src/transformers/integrations/sonicmoe.py | 6 ++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 70a343424aa8..a362b9e114f2 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -289,7 +289,7 @@ def register_kernel_mapping_transformers(*args, **kwargs): "falcon_mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "version": 1}, "finegrained-fp8": {"repo_id": "kernels-community/finegrained-fp8", "version": 1}, "deep-gemm": {"repo_id": "kernels-community/deep-gemm", "version": 1}, - "sonic-moe": {"repo_id": "kernels-community/sonic-moe", "version": 1}, + "sonic-moe": {"repo_id": "IlyasMoutawwakil/sonic-moe", "revision": "main"}, } _KERNEL_MODULE_MAPPING: dict[str, ModuleType | None] = {} @@ -376,7 +376,9 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _ repo_id = _HUB_KERNEL_MAPPING[kernel_name]["repo_id"] revision = _HUB_KERNEL_MAPPING[kernel_name].get("revision", None) version = _HUB_KERNEL_MAPPING[kernel_name].get("version", None) - kernel = get_kernel(repo_id, revision=revision, version=version) + # Entries in `_HUB_KERNEL_MAPPING` are vetted in-tree, so we trust non-`kernels-community` + # repos (e.g. user/team forks) without requiring the per-call `allow_all_kernels` flag. 
+ kernel = get_kernel(repo_id, revision=revision, version=version, allow_all_kernels=True) mapping[kernel_name] = kernel except FileNotFoundError: mapping[kernel_name] = None diff --git a/src/transformers/integrations/sonicmoe.py b/src/transformers/integrations/sonicmoe.py index d6eee485fea7..d32b698d5d74 100644 --- a/src/transformers/integrations/sonicmoe.py +++ b/src/transformers/integrations/sonicmoe.py @@ -112,6 +112,12 @@ def sonicmoe_experts_forward( router_scores = top_k_weights.reshape(-1).to(hidden_states.dtype) expert_ids = top_k_index.reshape(-1).int() + # EP sentinel handling: leave `expert_ids` unclamped — the kernel's metadata stage drops + # `expert_ids >= num_experts` from the per-expert histogram and masks them out of the + # scatter indices, so sentinels never enter the grouped GEMM. Their routing weights are + # already zero (RouterParallel masks them at dispatch), so the per-token reduction + # contributes nothing for sentinel slots. + # Map activation function act_name = getattr(self.config, "hidden_act", "silu").lower() activation_type = getattr(ActivationType, ACT_MAP.get(act_name, "swiglu").upper(), ActivationType.SWIGLU) From 60db1ca0706885deec9efb185d097d88d5dc0277 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil Date: Sun, 26 Apr 2026 13:00:34 +0000 Subject: [PATCH 21/22] fix --- src/transformers/integrations/moe.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index 4ef11fe029b7..9cf262de0358 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -30,6 +30,12 @@ if is_torch_available(): import torch + # Patch the version-check helpers so dynamo doesn't trace into them — they transitively call + # `importlib.util.find_spec`, which dynamo refuses to trace. `assume_constant_result` makes + # dynamo evaluate them once at trace time and inline the bool, no body tracing. + is_torch_greater_or_equal = torch._dynamo.assume_constant_result(is_torch_greater_or_equal) + is_torch_less_or_equal = torch._dynamo.assume_constant_result(is_torch_less_or_equal) + logger = logging.get_logger(__name__) From 68b7b0fe2dc4e1877ad7af6e20b0e700de37e69c Mon Sep 17 00:00:00 2001 From: IlyasMoutawwakil Date: Mon, 27 Apr 2026 15:56:52 +0200 Subject: [PATCH 22/22] compilable sonicmoe --- src/transformers/integrations/sonicmoe.py | 78 ++++++++++++++++------- 1 file changed, 56 insertions(+), 22 deletions(-) diff --git a/src/transformers/integrations/sonicmoe.py b/src/transformers/integrations/sonicmoe.py index d32b698d5d74..912b98655519 100644 --- a/src/transformers/integrations/sonicmoe.py +++ b/src/transformers/integrations/sonicmoe.py @@ -25,7 +25,6 @@ import torch from ..utils import logging -from ..utils.import_utils import is_kernels_available from .hub_kernels import lazy_load_kernel @@ -47,8 +46,6 @@ def _load_sonic_kernel(): Returns: Tuple of (ActivationType, moe_general_routing_inputs function) from the sonic-moe kernel. """ - if not is_kernels_available(): - raise ImportError("sonic-moe kernel requires the `kernels` package. 
Install it with `pip install -U kernels`.") if not torch.cuda.is_available(): raise ImportError( @@ -90,6 +87,50 @@ def _load_sonic_kernel(): return ActivationType, moe_general_routing_inputs +@torch._dynamo.allow_in_graph +def _sonicmoe_wrapper( + hidden_states: torch.Tensor, + router_scores: torch.Tensor, + expert_ids: torch.Tensor, + token_idx: torch.Tensor, + w1: torch.Tensor, + b1: torch.Tensor | None, + w2: torch.Tensor, + b2: torch.Tensor | None, + act_name: str, + num_experts: int, + concat_layout: bool, + is_inference_mode_enabled: bool, +) -> torch.Tensor: + """Module-level shim around `moe_general_routing_inputs` so `allow_in_graph` can wrap it. + + sonicmoe asserts `not torch.compiler.is_compiling()` internally because it dispatches + CuteDSL kernels, which Dynamo can't trace. `allow_in_graph` keeps the call in the FX + graph as a single opaque node (no tracing into the body, no graph break) while still + running the real Python at runtime — autograd through `_UpProjection` / `_DownProjection` + flows normally. The decorator must be applied at module load time, not inside the compiled + function — hence this shim plus the `allow_in_graph` decorator above. + """ + ActivationType, moe_general_routing_inputs = _load_sonic_kernel() + activation_type = getattr(ActivationType, ACT_MAP.get(act_name, "swiglu").upper(), ActivationType.SWIGLU) + output, _ = moe_general_routing_inputs( + hidden_states, + router_scores, + token_idx, + expert_ids, + w1, + b1, + w2, + b2, + E=num_experts, + activation_type=activation_type, + is_inference_mode_enabled=is_inference_mode_enabled, + concat_layout=concat_layout, + stream_id=None, + ) + return output + + def sonicmoe_experts_forward( self: torch.nn.Module, hidden_states: torch.Tensor, @@ -101,8 +142,6 @@ def sonicmoe_experts_forward( if hidden_states.device.type != "cuda": raise ValueError("sonicmoe requires CUDA device") - ActivationType, moe_general_routing_inputs = _load_sonic_kernel() - device = hidden_states.device num_top_k = top_k_index.size(-1) num_tokens = hidden_states.size(0) @@ -120,8 +159,6 @@ def sonicmoe_experts_forward( # Map activation function act_name = getattr(self.config, "hidden_act", "silu").lower() - activation_type = getattr(ActivationType, ACT_MAP.get(act_name, "swiglu").upper(), ActivationType.SWIGLU) - # Permute weights as expected by sonic-moe (E=num_experts, H=hidden_size, I=intermediate_size). # Non-transposed: gate_up_proj is (E, 2*I, H), down_proj is (E, H, I) -> permute(1, 2, 0). # Transposed: gate_up_proj is (E, H, 2*I), down_proj is (E, I, H) -> permute(2, 1, 0). @@ -131,20 +168,17 @@ def sonicmoe_experts_forward( b1 = self.gate_up_proj_bias if self.has_bias else None b2 = self.down_proj_bias if self.has_bias else None - output, _ = moe_general_routing_inputs( - hidden_states, - router_scores, - token_idx, - expert_ids, - w1, - b1, - w2, - b2, - E=self.num_experts, - activation_type=activation_type, - stream_id=torch.cuda.current_stream(device).cuda_stream, - is_inference_mode_enabled=not torch.is_grad_enabled(), + return _sonicmoe_wrapper( + hidden_states=hidden_states, + router_scores=router_scores, + expert_ids=expert_ids, + token_idx=token_idx, + w1=w1, + b1=b1, + w2=w2, + b2=b2, + act_name=act_name, + num_experts=self.num_experts, concat_layout=self.is_concatenated, + is_inference_mode_enabled=not torch.is_grad_enabled(), ) - - return output
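A minimal sketch of the `allow_in_graph` pattern used by `_sonicmoe_wrapper`, with a toy body standing in for the sonic-moe dispatch (which dynamo cannot trace): the decorated function is kept as a single opaque node in the compiled graph and runs as plain Python at call time.

    import torch

    @torch._dynamo.allow_in_graph
    def opaque_double(x: torch.Tensor) -> torch.Tensor:
        # Body is not traced into the FX graph; only the call node is recorded.
        return x * 2

    @torch.compile
    def fn(x):
        return opaque_double(x) + 1

    print(fn(torch.ones(3)))  # tensor([3., 3., 3.])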