diff --git a/src/transformers/integrations/finegrained_fp8.py b/src/transformers/integrations/finegrained_fp8.py index c64f1ce23ec2..f423f2f6b830 100644 --- a/src/transformers/integrations/finegrained_fp8.py +++ b/src/transformers/integrations/finegrained_fp8.py @@ -11,6 +11,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + +import functools + import torch import torch.nn as nn from torch.nn import functional as F @@ -19,7 +23,7 @@ from ..core_model_loading import ConversionOps, _IdentityOp from ..quantizers.quantizers_utils import should_convert_module from ..utils import logging -from ..utils.import_utils import get_cuda_runtime_version, resolve_internal_import +from ..utils.import_utils import get_cuda_runtime_version, is_kernels_available, resolve_internal_import from .hub_kernels import lazy_load_kernel from .moe import ExpertsInterface, use_experts_implementation @@ -31,26 +35,12 @@ _FP8_MIN = torch.finfo(_FP8_DTYPE).min _FP8_MAX = torch.finfo(_FP8_DTYPE).max -# DeepGEMM requires M-dimension alignment to 128 for TMA-based contiguous grouped GEMM + +# DeepGEMM requires M-dimension alignment to 128 for TMA-based contiguous grouped GEMM. # TMA is an H100 hardware addition that allows applications to asynchronously and -# bi-directionally transfer 1D-5D tensors between GPU global and shared memory +# bi-directionally transfer 1D-5D tensors between GPU global and shared memory. _DEEPGEMM_M_ALIGNMENT = 128 -# Lazily-loaded finegrained-fp8 Triton kernel functions (populated by _load_triton_kernel) -triton_fp8_matmul = None -triton_fp8_act_quant = None -triton_batched_fp8_matmul = None -triton_grouped_fp8_matmul = None -# _triton_available: None = not yet attempted, True = loaded, False = failed (won't retry) -_triton_available = None - -# Lazily-loaded DeepGEMM kernel functions (populated by _load_deepgemm_kernel) -deepgemm_fp8_matmul = None -deepgemm_grouped_fp8_matmul = None -deepgemm_per_token_cast_to_fp8 = None -# _deepgemm_available: None = not yet attempted, True = loaded, False = failed (won't retry) -_deepgemm_available = None - def _first_attr(obj, *names): for name in names: @@ -59,27 +49,31 @@ def _first_attr(obj, *names): raise AttributeError(f"{type(obj).__name__} has none of: {names}") +@functools.cache def _load_triton_kernel(): - """Lazily load the finegrained-fp8 Triton kernel and extract functions. - - Uses the hub kernels lazy loading pattern. Raises an error if the kernel - cannot be loaded or required functions are missing. Only attempts loading once. """ - global \ - _triton_available, \ - triton_fp8_act_quant, \ - triton_fp8_matmul, \ - triton_batched_fp8_matmul, \ - triton_grouped_fp8_matmul + Load the finegrained-fp8 Triton kernel once and return its required symbols. - if _triton_available is not None: - if not _triton_available: - raise ImportError("finegrained-fp8 kernel is not available (previous load attempt failed).") - return + Raises: + ImportError if the `kernels` package is missing, or the kernel or required + symbols cannot be found. - _triton_available = False # mark attempted before any early exit + Returns: + Tuple of (w8a8_fp8_matmul, fp8_act_quant, w8a8_fp8_matmul_batched, + w8a8_fp8_matmul_grouped) from the finegrained-fp8 kernel. + """ + if not is_kernels_available(): + raise ImportError( + "finegrained-fp8 kernel requires the `kernels` package. 
Install it with `pip install -U kernels`." + ) kernel = lazy_load_kernel("finegrained-fp8") + if kernel is None: + raise ImportError( + "Failed to load the finegrained-fp8 kernel — check that `kernels-community/finegrained-fp8` " + "has a build matching the current torch/CUDA." + ) + triton_fp8_matmul = getattr(kernel, "w8a8_fp8_matmul", None) triton_fp8_act_quant = getattr(kernel, "fp8_act_quant", None) triton_batched_fp8_matmul = getattr(kernel, "w8a8_fp8_matmul_batched", None) @@ -97,30 +91,29 @@ def _load_triton_kernel(): ] if missing: raise ImportError( - f"finegrained-fp8 kernel is missing required functions: {', '.join(missing)}. " + f"finegrained-fp8 kernel is missing required symbols: {', '.join(missing)}. " "Please update the `kernels` package (`pip install -U kernels`)." ) - _triton_available = True + return triton_fp8_matmul, triton_fp8_act_quant, triton_batched_fp8_matmul, triton_grouped_fp8_matmul +@functools.cache def _load_deepgemm_kernel(): - """Lazily load the DeepGEMM kernel and extract functions with proper names. - - Uses the hub kernels lazy loading pattern. Raises an error if the kernel - cannot be loaded, required functions are missing, or the hardware is insufficient. - Only attempts loading once. """ - global _deepgemm_available, deepgemm_fp8_matmul, deepgemm_grouped_fp8_matmul, deepgemm_per_token_cast_to_fp8 + Load DeepGEMM once and return its required symbols. - if _deepgemm_available is not None: - if not _deepgemm_available: - raise ImportError("DeepGEMM kernel is not available (previous load attempt failed).") - return + Raises: + ImportError if CUDA/hardware requirements are not met, or the kernel or + required symbols are not found. - _deepgemm_available = False # mark attempted before any early exit + Returns: + Tuple of (deepgemm_fp8_matmul, deepgemm_grouped_fp8_matmul, deepgemm_per_token_cast_to_fp8) + from the DeepGEMM kernel. + """ + if not is_kernels_available(): + raise ImportError("DeepGEMM kernel requires the `kernels` package. Install it with `pip install -U kernels`.") - # DeepGEMM requires CUDA and a compatible GPU if not torch.cuda.is_available(): raise ImportError( "DeepGEMM kernel requires CUDA, but CUDA is not available. Use a different `experts_implementation`." @@ -134,7 +127,7 @@ def _load_deepgemm_kernel(): f"has compute capability {major}.x. Use a different `experts_implementation`." ) - # DeepGEMM requires CUDA runtime ≥ 12.3. + # DeepGEMM requires CUDA runtime >= 12.3 cuda_major, cuda_minor = get_cuda_runtime_version() if cuda_major < 12 or (cuda_major == 12 and cuda_minor < 3): raise ImportError( @@ -143,6 +136,12 @@ def _load_deepgemm_kernel(): ) kernel = lazy_load_kernel("deep-gemm") + if kernel is None: + raise ImportError( + "Failed to load the DeepGEMM kernel — check that `kernels-community/deep-gemm` " + "has a build matching the current torch/CUDA." + ) + deepgemm_fp8_matmul = getattr(kernel, "fp8_gemm_nt", None) deepgemm_grouped_fp8_matmul = getattr(kernel, "m_grouped_fp8_gemm_nt_contiguous", None) deepgemm_per_token_cast_to_fp8 = resolve_internal_import(kernel, chained_path="utils.per_token_cast_to_fp8") @@ -158,11 +157,11 @@ def _load_deepgemm_kernel(): ] if missing: raise ImportError( - f"DeepGEMM kernel is missing required functions: {', '.join(missing)}. " + f"DeepGEMM kernel is missing required symbols: {', '.join(missing)}. " "Please update the `kernels` package (`pip install -U kernels`)." 
) - _deepgemm_available = True + return deepgemm_fp8_matmul, deepgemm_grouped_fp8_matmul, deepgemm_per_token_cast_to_fp8 def _cdiv(a: int, b: int) -> int: @@ -198,8 +197,7 @@ def w8a8_fp8_matmul( """ if block_size is not None and block_size[0] == block_size[1] == 128: try: - _load_deepgemm_kernel() - global deepgemm_fp8_matmul + deepgemm_fp8_matmul, _, _ = _load_deepgemm_kernel() except ImportError: logger.warning_once( "DeepGEMM kernel is not available or compatible, falling back to Triton finegrained-fp8 kernel. " @@ -214,8 +212,7 @@ def w8a8_fp8_matmul( deepgemm_fp8_matmul((A_2d, As_2d.float()), (B, Bs.float()), output) return output.view(A.shape[:-1] + (B.shape[0],)) - _load_triton_kernel() - global triton_fp8_matmul + triton_fp8_matmul, _, _, _ = _load_triton_kernel() return triton_fp8_matmul(A, B, As, Bs, block_size, output_dtype) @@ -269,8 +266,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: scale_inv = self.weight_scale_inv.contiguous() if self.activation_scheme == "dynamic": - _load_triton_kernel() - global triton_fp8_act_quant + _, triton_fp8_act_quant, _, _ = _load_triton_kernel() qinput, scale = triton_fp8_act_quant( input, self.block_size[1] if self.block_size is not None else input.shape[-1] ) @@ -290,7 +286,7 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: ) if self.bias is not None: - output = output + self.bias + output.add_(self.bias) return output.to(dtype=input.dtype) @@ -307,21 +303,22 @@ def fp8_batched_mm_experts_forward( "Use the default eager dispatch or switch to activation_scheme='dynamic'." ) - _load_triton_kernel() - global triton_batched_fp8_matmul + _, _, triton_batched_fp8_matmul, _ = _load_triton_kernel() - device = hidden_states.device num_top_k = top_k_index.size(-1) num_tokens = hidden_states.size(0) hidden_dim = hidden_states.size(-1) # S is the number of selected tokens-experts pairs (S = num_tokens * num_top_k) - token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,) + # Replicate each token num_top_k times to align with the flattened (S,) routing tensors. + selected_hidden_states = hidden_states.repeat_interleave(num_top_k, dim=0) sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) - # Get current hidden states for selected samples - selected_hidden_states = hidden_states[token_idx] + # Clamp EP sentinels so per-token weight indexing stays in-bounds. Routing weights are already + # zero at sentinel slots (RouterParallel masks them at dispatch), so the weighted mul drops + # those contributions — we pay the wasted GEMM compute because batched_mm has no offset to skip. + expert_ids.clamp_(0, self.num_experts - 1) # --- Up projection per expert (FP8 batched) --- proj_out = triton_batched_fp8_matmul( @@ -371,8 +368,7 @@ def fp8_grouped_mm_experts_forward( "Use the default eager dispatch or switch to activation_scheme='dynamic'." 
) - _load_triton_kernel() - global triton_grouped_fp8_matmul + _, _, _, triton_grouped_fp8_matmul = _load_triton_kernel() device = hidden_states.device num_top_k = top_k_index.size(-1) @@ -380,18 +376,18 @@ def fp8_grouped_mm_experts_forward( hidden_dim = hidden_states.size(-1) # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) - token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,) sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) - # Sort by expert for grouped processing - perm = torch.argsort(expert_ids) - inv_perm = torch.empty_like(perm) - inv_perm[perm] = torch.arange(perm.size(0), device=device) + # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, + # `histc(max=num_experts-1)` drops them from `tokens_per_expert`, and the grouped matmul skips + # rows beyond `offsets[-1]` — so sentinels cost no real GEMM compute. Sentinel rows are zeroed + # post-weighted-mul (see below), since the kernel leaves them uninitialized. - expert_ids_g = expert_ids[perm] + # Sort by expert for grouped processing + expert_ids_g, perm = torch.sort(expert_ids) + selected_hidden_states_g = hidden_states[perm // num_top_k] sample_weights_g = sample_weights[perm] - selected_hidden_states_g = hidden_states[token_idx[perm]] # Compute offsets for grouped processing. # histc instead of bincount avoids cuda-graph issues; @@ -431,7 +427,14 @@ def fp8_grouped_mm_experts_forward( # Apply routing weights weighted_out = proj_out * sample_weights_g.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim) + # EP sentinel handling: `proj_out` rows past `offsets[-1]` are left uninitialized by the kernel, + # so `proj_out[sentinel] * 0 = 0 * NaN = NaN` can leak from allocator pool reuse. Zero them here + # so the downstream reduction stays finite even when the routing weight was already zero. + weighted_out.masked_fill_((expert_ids_g >= self.num_experts).unsqueeze(-1), 0.0) + # Restore original order + inv_perm = torch.empty_like(perm) + inv_perm[perm] = torch.arange(perm.size(0), device=device) weighted_out = weighted_out[inv_perm] # Accumulate results using deterministic reshape+sum instead of index_add_ @@ -441,35 +444,45 @@ def fp8_grouped_mm_experts_forward( return final_hidden_states.to(hidden_states.dtype) -def _build_deepgemm_contiguous_layout(expert_ids_sorted: torch.Tensor, num_experts: int, alignment: int) -> tuple: - """Build a TMA-aligned contiguous layout for DeepGEMM grouped GEMM. +def _build_deepgemm_contiguous_layout( + expert_ids_sorted: torch.Tensor, num_experts: int, alignment: int, use_psum_layout: bool +) -> tuple: + """Build the TMA-aligned layout DeepGEMM's grouped GEMM expects. - DeepGEMM requires M-dimension alignment per expert for TMA. This computes - the mapping from sorted token positions to padded row positions, and the - layout tensor that DeepGEMM uses to identify expert boundaries. + Returns `(sorted_to_padded, grouped_layout, total_padded_rows)`. `grouped_layout` encodes + expert boundaries as a cumsum of aligned counts on Blackwell (`use_psum_layout=True`) or + per-row expert ids with -1 for padding on Hopper. 
- Returns: - sorted_to_padded: (num_tokens,) index map from sorted position to padded row - grouped_layout: expert layout tensor (format depends on GPU architecture) - total_padded_rows: total number of rows including alignment padding + Accepts EP sentinels: values in `expert_ids_sorted` equal to `num_experts` (unclamped sentinels) + are routed past the last aligned expert block and marked `-1` in the Hopper layout (and + excluded from the Blackwell cumsum), so DeepGEMM skips them. """ device = expert_ids_sorted.device num_tokens = expert_ids_sorted.size(0) + # histc drops values > max, so EP sentinels (== num_experts) are excluded from the per-expert count. tokens_per_expert = torch.histc(expert_ids_sorted.int(), bins=num_experts, min=0, max=num_experts - 1).long() aligned_tokens_per_expert = ((tokens_per_expert + alignment - 1) // alignment) * alignment - # Upper bound avoids GPU→CPU sync; padding rows are skipped by DeepGEMM. + # Upper bound avoids GPU->CPU sync; padding rows are skipped by DeepGEMM. total_padded_rows = num_tokens + min(num_tokens, num_experts) * (alignment - 1) + # Zero-prepended inclusive cumsum of per-expert padding. Indices [0, num_experts) give the + # exclusive cumsum (padding before expert i) and index `num_experts` gives `sum(padding)`, + # which routes EP sentinels past all valid aligned expert blocks on Blackwell (where the + # kernel stops at `aligned_cumsum[-1]`) — so sentinels don't go through the GEMM. padding_per_expert = aligned_tokens_per_expert - tokens_per_expert - cumulative_padding = padding_per_expert.cumsum(0) - padding_per_expert + cumulative_padding = torch.nn.functional.pad(padding_per_expert.cumsum(0), (1, 0)) sorted_to_padded = torch.arange(num_tokens, device=device) + cumulative_padding[expert_ids_sorted] - if torch.cuda.get_device_capability(device)[0] >= 10: # Blackwell (SM100+) - grouped_layout = tokens_per_expert.cumsum(0).int() + if use_psum_layout: # Blackwell (SM100+) + # psum layout: cumsum of *aligned* per-expert counts — sentinels sit at positions >= + # `grouped_layout[-1]` (by construction of `cumulative_padding`), so the scheduler + # stops before them. The kernel's `num_m_blocks = ceil_div(layout[i] - align(layout[i-1], 128), BLOCK_M)` + # between experts only matches the padded tensor when the stored cumsum is over aligned counts. + grouped_layout = aligned_tokens_per_expert.cumsum(0).int() else: - # Hopper: per-row expert id, -1 for padding rows + # Hopper: per-row expert id, -1 for padding rows and for sentinel slots (kernel skips -1). grouped_layout = torch.full((total_padded_rows,), -1, device=device, dtype=torch.int32) - grouped_layout[sorted_to_padded] = expert_ids_sorted.int() + grouped_layout[sorted_to_padded] = torch.where(expert_ids_sorted < num_experts, expert_ids_sorted.int(), -1) return sorted_to_padded, grouped_layout, total_padded_rows @@ -505,7 +518,7 @@ def fp8_deepgemm_experts_forward( ) -> torch.Tensor: if self.activation_scheme == "static": raise NotImplementedError( - "deepgemm experts dispatch does not support activation_scheme='static'. " + "DeepGEMM experts dispatch does not support activation_scheme='static'. " "Use the default eager dispatch or switch to activation_scheme='dynamic'." 
) if self.block_size is None: @@ -516,8 +529,7 @@ def fp8_deepgemm_experts_forward( if self.block_size[0] != 128 or self.block_size[1] != 128: raise ValueError(f"DeepGEMM requires block_size=(128, 128), got {self.block_size}") - _load_deepgemm_kernel() - global deepgemm_grouped_fp8_matmul, deepgemm_per_token_cast_to_fp8 + _, deepgemm_grouped_fp8_matmul, deepgemm_per_token_cast_to_fp8 = _load_deepgemm_kernel() device = hidden_states.device num_top_k = top_k_index.size(-1) @@ -525,22 +537,22 @@ def fp8_deepgemm_experts_forward( hidden_dim = hidden_states.size(-1) # S is the number of selected token-expert pairs (S = num_tokens * num_top_k) - token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,) sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) - # Sort by expert for grouped processing - perm = torch.argsort(expert_ids) - inv_perm = torch.empty_like(perm) - inv_perm[perm] = torch.arange(perm.size(0), device=device) + # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, + # `_build_deepgemm_contiguous_layout` marks their positions as skipped (-1 on Hopper, beyond the + # cumsum on Blackwell), and DeepGEMM skips them — so sentinels cost no real GEMM compute. + # Sentinel rows are zeroed post-weighted-mul (see below), since the kernel leaves them uninitialized. - expert_ids_g = expert_ids[perm] + # Sort by expert for grouped processing + expert_ids_g, perm = torch.sort(expert_ids) + selected_hidden_states_g = hidden_states[perm // num_top_k] sample_weights_g = sample_weights[perm] - selected_hidden_states_g = hidden_states[token_idx[perm]] - # Build TMA-aligned contiguous layout for DeepGEMM + use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 sorted_to_padded, grouped_layout, total_padded_rows = _build_deepgemm_contiguous_layout( - expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT + expert_ids_g, self.num_experts, alignment=_DEEPGEMM_M_ALIGNMENT, use_psum_layout=use_psum_layout ) # --- Up projection per expert (DeepGEMM grouped contiguous) --- @@ -548,8 +560,7 @@ def fp8_deepgemm_experts_forward( ws_up = self.gate_up_proj_scale_inv if self.has_gate else self.up_proj_scale_inv act_fp8, act_scales = deepgemm_per_token_cast_to_fp8(selected_hidden_states_g, use_ue8m0=False) act_fp8, act_scales = _pad_to_deepgemm_contiguous_layout(act_fp8, act_scales, sorted_to_padded, total_padded_rows) - proj_out = torch.zeros(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16) - use_psum_layout = torch.cuda.get_device_capability(device)[0] >= 10 + proj_out = torch.empty(total_padded_rows, w_up.shape[1], device=device, dtype=torch.bfloat16) deepgemm_grouped_fp8_matmul( (act_fp8, act_scales), (w_up, ws_up.float()), proj_out, grouped_layout, use_psum_layout=use_psum_layout ) @@ -561,12 +572,14 @@ def fp8_deepgemm_experts_forward( proj_out = self.act_fn(proj_out) # --- Down projection per expert (DeepGEMM grouped contiguous) --- - w_down = self.down_proj - ws_down = self.down_proj_scale_inv proj_fp8, proj_scales = deepgemm_per_token_cast_to_fp8(proj_out, use_ue8m0=False) - proj_out = torch.zeros(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) + proj_out = torch.empty(total_padded_rows, hidden_dim, device=device, dtype=torch.bfloat16) deepgemm_grouped_fp8_matmul( - (proj_fp8, proj_scales), (w_down, ws_down.float()), proj_out, grouped_layout, use_psum_layout=use_psum_layout + (proj_fp8, proj_scales), + 
(self.down_proj, self.down_proj_scale_inv.float()), + proj_out, + grouped_layout, + use_psum_layout=use_psum_layout, ) # Remove padding rows @@ -575,7 +588,14 @@ def fp8_deepgemm_experts_forward( # Apply routing weights weighted_out = proj_out * sample_weights_g.to(proj_out.dtype).unsqueeze(-1) # (S, hidden_dim) + # EP sentinel handling: `proj_out` rows past the valid expert blocks are left uninitialized by the kernel, + # so `proj_out[sentinel] * 0 = 0 * NaN = NaN` can leak from allocator pool reuse. Zero them here + # so the downstream reduction stays finite even when the routing weight was already zero. + weighted_out.masked_fill_((expert_ids_g >= self.num_experts).unsqueeze(-1), 0.0) + # Restore original order + inv_perm = torch.empty_like(perm) + inv_perm[perm] = torch.arange(perm.size(0), device=device) weighted_out = weighted_out[inv_perm] # Accumulate results using deterministic reshape+sum instead of index_add_ @@ -703,8 +723,7 @@ def linear( scale = activation_scale.to(torch.float32) qinput = (input / scale).clamp(min=_FP8_MIN, max=_FP8_MAX).to(_FP8_DTYPE) else: - _load_triton_kernel() - global triton_fp8_act_quant + _, triton_fp8_act_quant, _, _ = _load_triton_kernel() qinput, scale = triton_fp8_act_quant( input, self.block_size[1] if self.block_size is not None else input.shape[-1] ) @@ -916,5 +935,5 @@ def convert( } @property - def reverse_op(self) -> "ConversionOps": + def reverse_op(self) -> ConversionOps: return _IdentityOp() diff --git a/src/transformers/integrations/hub_kernels.py b/src/transformers/integrations/hub_kernels.py index 70a343424aa8..a362b9e114f2 100644 --- a/src/transformers/integrations/hub_kernels.py +++ b/src/transformers/integrations/hub_kernels.py @@ -289,7 +289,7 @@ def register_kernel_mapping_transformers(*args, **kwargs): "falcon_mamba-ssm": {"repo_id": "kernels-community/mamba-ssm", "version": 1}, "finegrained-fp8": {"repo_id": "kernels-community/finegrained-fp8", "version": 1}, "deep-gemm": {"repo_id": "kernels-community/deep-gemm", "version": 1}, - "sonic-moe": {"repo_id": "kernels-community/sonic-moe", "version": 1}, + "sonic-moe": {"repo_id": "IlyasMoutawwakil/sonic-moe", "revision": "main"}, } _KERNEL_MODULE_MAPPING: dict[str, ModuleType | None] = {} @@ -376,7 +376,9 @@ def lazy_load_kernel(kernel_name: str, mapping: dict[str, ModuleType | None] = _ repo_id = _HUB_KERNEL_MAPPING[kernel_name]["repo_id"] revision = _HUB_KERNEL_MAPPING[kernel_name].get("revision", None) version = _HUB_KERNEL_MAPPING[kernel_name].get("version", None) - kernel = get_kernel(repo_id, revision=revision, version=version) + # Entries in `_HUB_KERNEL_MAPPING` are vetted in-tree, so we trust non-`kernels-community` + # repos (e.g. user/team forks) without requiring the per-call `allow_all_kernels` flag. + kernel = get_kernel(repo_id, revision=revision, version=version, allow_all_kernels=True) mapping[kernel_name] = kernel except FileNotFoundError: mapping[kernel_name] = None diff --git a/src/transformers/integrations/moe.py b/src/transformers/integrations/moe.py index c8a8e87f3621..9cf262de0358 100644 --- a/src/transformers/integrations/moe.py +++ b/src/transformers/integrations/moe.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from __future__ import annotations from collections.abc import Callable from functools import wraps @@ -29,6 +30,12 @@ if is_torch_available(): import torch + # Patch the version-check helpers so dynamo doesn't trace into them — they transitively call + # `importlib.util.find_spec`, which dynamo refuses to trace. `assume_constant_result` makes + # dynamo evaluate them once at trace time and inline the bool, no body tracing. + is_torch_greater_or_equal = torch._dynamo.assume_constant_result(is_torch_greater_or_equal) + is_torch_less_or_equal = torch._dynamo.assume_constant_result(is_torch_less_or_equal) + logger = logging.get_logger(__name__) @@ -102,7 +109,7 @@ def _batched_linear( out = torch.bmm(weight, input.unsqueeze(-1)).squeeze(-1) if bias is not None: - out = out + bias + out.add_(bias) return out @@ -113,24 +120,20 @@ def batched_mm_experts_forward( top_k_index: torch.Tensor, top_k_weights: torch.Tensor, ) -> torch.Tensor: - device = hidden_states.device num_top_k = top_k_index.size(-1) num_tokens = hidden_states.size(0) hidden_dim = hidden_states.size(-1) - # Reshape for easier indexing # S is the number of selected tokens-experts pairs (S = num_tokens * num_top_k) - token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,) + # Replicate each token num_top_k times to align with the flattened (S,) routing tensors. + selected_hidden_states = hidden_states.repeat_interleave(num_top_k, dim=0) sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) - # Handle invalid expert IDs from Expert Parallelism (EP) - # When EP is enabled, tokens assigned to experts on other devices are marked with sentinel value >= num_experts - invalid_mask = expert_ids >= self.num_experts - expert_ids = expert_ids.clamp(0, self.num_experts - 1) - - # Get current hidden states for selected samples - selected_hidden_states = hidden_states[token_idx] + # Clamp EP sentinels so `gate_up_proj[expert_ids]` stays in-bounds. Routing weights are already + # zero at sentinel slots (RouterParallel masks them at dispatch), so the weighted mul drops + # those contributions — we pay the wasted GEMM compute because batched_mm has no offset to skip. + expert_ids.clamp_(0, self.num_experts - 1) # Select gate_up or just up projection weights and biases if self.has_gate: @@ -162,9 +165,8 @@ def batched_mm_experts_forward( proj_out, selected_weights, bias=selected_biases, is_transposed=self.is_transposed ) # (S, hidden_dim) - # Apply routing weights and zero out invalid expert contributions + # Apply routing weights weighted_out = proj_out * sample_weights.unsqueeze(-1) # (S, hidden_dim) - weighted_out.masked_fill_(invalid_mask.unsqueeze(-1), 0.0) # Zero out invalid expert contributions # Accumulate results using deterministic reshape+sum instead of index_add_ # index_add_ with duplicate indices is non-deterministic on CUDA due to atomicAdd @@ -363,7 +365,7 @@ def _grouped_linear( if bias is not None: # We should be able to pass bias to the grouped_mm call, but it's not yet supported. 
- out = out + bias + out.add_(bias) return out @@ -379,24 +381,19 @@ def grouped_mm_experts_forward( num_tokens = hidden_states.size(0) hidden_dim = hidden_states.size(-1) - # Reshape for easier indexing # S is the number of selected tokens-experts pairs (S = num_tokens * num_top_k) - token_idx = torch.arange(num_tokens, device=device).unsqueeze(1).expand(-1, num_top_k).reshape(-1) # (S,) sample_weights = top_k_weights.reshape(-1) # (S,) expert_ids = top_k_index.reshape(-1) # (S,) - # Handle invalid expert IDs from Expert Parallelism (EP) - invalid_mask = expert_ids >= self.num_experts - expert_ids = expert_ids.clamp(0, self.num_experts - 1) + # EP sentinel handling: leave `expert_ids` unclamped so the sort pushes sentinels to the tail, + # `histc(max=num_experts-1)` drops them from `tokens_per_expert`, and grouped_mm skips rows + # beyond `offsets[-1]` — so sentinels cost no real GEMM compute. Sentinel rows are zeroed + # post-weighted-mul (see below), since the kernel leaves them uninitialized. # Sort by expert for grouped processing - perm = torch.argsort(expert_ids) - inv_perm = torch.empty_like(perm) - inv_perm[perm] = torch.arange(perm.size(0), device=device) - - expert_ids_g = expert_ids[perm] + expert_ids_g, perm = torch.sort(expert_ids) + selected_hidden_states_g = hidden_states[perm // num_top_k] sample_weights_g = sample_weights[perm] - selected_hidden_states_g = hidden_states[token_idx[perm]] # Compute offsets for grouped_mm # using histc instead of bincount to avoid cuda graph issues @@ -405,11 +402,17 @@ def grouped_mm_experts_forward( tokens_per_expert = torch.histc(histc_input, bins=self.num_experts, min=0, max=self.num_experts - 1) offsets = torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32) + if self.has_bias: + # Clamp now that the layout has been built — needed for the per-row bias gather below to stay + # in-bounds. Bias added to sentinel positions falls in rows the kernel skips, so harmless. + expert_ids_g.clamp_(0, self.num_experts - 1) + # Select expert weights and biases # NOTE: We keep all experts here and rely on offsets to target the active ones. # I have already implemented a version that only passes the active experts, but # to do so I had to use torch.unique which breaks the graph capture (data-dependent). # Also there were no speedup gains from it in my experiments, even in eager mode. + # NOTE: The grouped_mm kernel only targets the active experts / tokens via the offsets if self.has_gate: selected_weights = self.gate_up_proj selected_biases = self.gate_up_proj_bias[expert_ids_g] if self.has_bias else None @@ -439,12 +442,17 @@ def grouped_mm_experts_forward( proj_out, selected_weights, offsets, bias=selected_biases, is_transposed=self.is_transposed ) # (S, hidden_dim) - # Apply routing weights and zero out invalid expert contributions from EP + # Apply routing weights weighted_out = proj_out * sample_weights_g.unsqueeze(-1) # (S, hidden_dim) - invalid_mask_g = invalid_mask[perm] - weighted_out.masked_fill_(invalid_mask_g.unsqueeze(-1), 0.0) + + # EP sentinel handling: `proj_out` rows past `offsets[-1]` are left uninitialized by grouped_mm, + # so `proj_out[sentinel] * 0 = 0 * NaN = NaN` can leak from allocator pool reuse. Zero them here + # so the downstream reduction stays finite even when the routing weight was already zero. 
+ weighted_out.masked_fill_((expert_ids_g >= self.num_experts).unsqueeze(-1), 0.0) # Restore original order + inv_perm = torch.empty_like(perm) + inv_perm[perm] = torch.arange(perm.size(0), device=device) weighted_out = weighted_out[inv_perm] # (S, hidden_dim) # Accumulate results using deterministic reshape+sum instead of index_add_ @@ -460,9 +468,9 @@ class ExpertsInterface(GeneralInterface): """Interface for registering custom experts forward functions.""" _global_mapping = { - "sonicmoe": sonicmoe_experts_forward, "batched_mm": batched_mm_experts_forward, "grouped_mm": grouped_mm_experts_forward, + "sonicmoe": sonicmoe_experts_forward, } def get_interface(self, experts_implementation: str, default: Callable) -> Callable: diff --git a/src/transformers/integrations/sonicmoe.py b/src/transformers/integrations/sonicmoe.py index e322bb4bc061..912b98655519 100644 --- a/src/transformers/integrations/sonicmoe.py +++ b/src/transformers/integrations/sonicmoe.py @@ -18,6 +18,8 @@ Requirements: CUDA, `kernels`, `nvidia-cutlass-dsl`, has_gate=True. """ +from __future__ import annotations + import functools import torch @@ -38,16 +40,31 @@ def _load_sonic_kernel(): Load sonic-moe once and return its required symbols. Raises: - ImportError if the kernel or required symbols are not found. + ImportError if CUDA/hardware requirements are not met, or if the kernel or + required symbols are not found. Returns: Tuple of (ActivationType, moe_general_routing_inputs function) from the sonic-moe kernel. """ + if not torch.cuda.is_available(): + raise ImportError( + "sonic-moe kernel requires CUDA, but CUDA is not available. Use a different `experts_implementation`." + ) + + # sonic-moe requires Hopper (SM90) or newer + major = torch.cuda.get_device_capability()[0] + if major < 9: + raise ImportError( + f"sonic-moe requires a Hopper (SM90+) or newer GPU, but the current device " + f"has compute capability {major}.x. Use a different `experts_implementation`." + ) + kernel = lazy_load_kernel("sonic-moe") if kernel is None: raise ImportError( - "sonic-moe kernel not found. Make sure you have the `kernels` and `nvidia-cutlass-dsl` packages installed." + "Failed to load the sonic-moe kernel — check that `kernels-community/sonic-moe` " + "has a build matching the current torch/CUDA." ) ActivationType = getattr(getattr(kernel, "enums", None), "ActivationType", None) @@ -70,6 +87,50 @@ def _load_sonic_kernel(): return ActivationType, moe_general_routing_inputs +@torch._dynamo.allow_in_graph +def _sonicmoe_wrapper( + hidden_states: torch.Tensor, + router_scores: torch.Tensor, + expert_ids: torch.Tensor, + token_idx: torch.Tensor, + w1: torch.Tensor, + b1: torch.Tensor | None, + w2: torch.Tensor, + b2: torch.Tensor | None, + act_name: str, + num_experts: int, + concat_layout: bool, + is_inference_mode_enabled: bool, +) -> torch.Tensor: + """Module-level shim around `moe_general_routing_inputs` so `allow_in_graph` can wrap it. + + sonicmoe asserts `not torch.compiler.is_compiling()` internally because it dispatches + CuteDSL kernels, which Dynamo can't trace. `allow_in_graph` keeps the call in the FX + graph as a single opaque node (no tracing into the body, no graph break) while still + running the real Python at runtime — autograd through `_UpProjection` / `_DownProjection` + flows normally. The decorator must be applied at module load time, not inside the compiled + function — hence this shim plus the `allow_in_graph` decorator above. 
+ """ + ActivationType, moe_general_routing_inputs = _load_sonic_kernel() + activation_type = getattr(ActivationType, ACT_MAP.get(act_name, "swiglu").upper(), ActivationType.SWIGLU) + output, _ = moe_general_routing_inputs( + hidden_states, + router_scores, + token_idx, + expert_ids, + w1, + b1, + w2, + b2, + E=num_experts, + activation_type=activation_type, + is_inference_mode_enabled=is_inference_mode_enabled, + concat_layout=concat_layout, + stream_id=None, + ) + return output + + def sonicmoe_experts_forward( self: torch.nn.Module, hidden_states: torch.Tensor, @@ -81,8 +142,6 @@ def sonicmoe_experts_forward( if hidden_states.device.type != "cuda": raise ValueError("sonicmoe requires CUDA device") - ActivationType, moe_general_routing_inputs = _load_sonic_kernel() - device = hidden_states.device num_top_k = top_k_index.size(-1) num_tokens = hidden_states.size(0) @@ -92,10 +151,14 @@ def sonicmoe_experts_forward( router_scores = top_k_weights.reshape(-1).to(hidden_states.dtype) expert_ids = top_k_index.reshape(-1).int() + # EP sentinel handling: leave `expert_ids` unclamped — the kernel's metadata stage drops + # `expert_ids >= num_experts` from the per-expert histogram and masks them out of the + # scatter indices, so sentinels never enter the grouped GEMM. Their routing weights are + # already zero (RouterParallel masks them at dispatch), so the per-token reduction + # contributes nothing for sentinel slots. + # Map activation function act_name = getattr(self.config, "hidden_act", "silu").lower() - activation_type = getattr(ActivationType, ACT_MAP.get(act_name, "swiglu").upper(), ActivationType.SWIGLU) - # Permute weights as expected by sonic-moe (E=num_experts, H=hidden_size, I=intermediate_size). # Non-transposed: gate_up_proj is (E, 2*I, H), down_proj is (E, H, I) -> permute(1, 2, 0). # Transposed: gate_up_proj is (E, H, 2*I), down_proj is (E, I, H) -> permute(2, 1, 0). @@ -105,20 +168,17 @@ def sonicmoe_experts_forward( b1 = self.gate_up_proj_bias if self.has_bias else None b2 = self.down_proj_bias if self.has_bias else None - output, _ = moe_general_routing_inputs( - hidden_states, - router_scores, - token_idx, - expert_ids, - w1, - b1, - w2, - b2, - E=self.num_experts, - activation_type=activation_type, - stream_id=torch.cuda.current_stream(device).cuda_stream, - is_inference_mode_enabled=not torch.is_grad_enabled(), + return _sonicmoe_wrapper( + hidden_states=hidden_states, + router_scores=router_scores, + expert_ids=expert_ids, + token_idx=token_idx, + w1=w1, + b1=b1, + w2=w2, + b2=b2, + act_name=act_name, + num_experts=self.num_experts, concat_layout=self.is_concatenated, + is_inference_mode_enabled=not torch.is_grad_enabled(), ) - - return output diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py index de11d23cbecf..8654bd083ba2 100644 --- a/src/transformers/utils/import_utils.py +++ b/src/transformers/utils/import_utils.py @@ -222,14 +222,28 @@ def is_cuda_platform() -> bool: def get_cuda_runtime_version() -> tuple[int, int]: """Return the CUDA runtime version as (major, minor). - Unlike ``torch.version.cuda`` which reports the compile-time version, - this queries ``cudaRuntimeGetVersion`` from ``libcudart.so`` to get the - actual runtime version installed on the system. + Prefers a direct query of ``cudaRuntimeGetVersion`` via ``libcudart.so``. 
If that's + not on the system loader path (common with pip-installed torch that bundles its own + CUDA runtime), falls back to ``torch.version.cuda`` — which equals the bundled + runtime's version for pip wheels. Returns ``(0, 0)`` for CPU-only torch. """ import ctypes + try: + cudart = ctypes.CDLL("libcudart.so") + except OSError: + if not is_torch_available(): + return 0, 0 + import torch + + cuda_version = getattr(torch.version, "cuda", None) + if cuda_version is None: + return 0, 0 + + major, minor, *_ = cuda_version.split(".") + return int(major), int(minor) + version = ctypes.c_int() - cudart = ctypes.CDLL("libcudart.so") cudart.cudaRuntimeGetVersion(ctypes.byref(version)) return version.value // 1000, (version.value % 1000) // 10
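
A minimal sketch of the @functools.cache loader pattern the diff switches to, assuming a placeholder package name and placeholder symbol names rather than the real kernel API. It shows the behavior the refactor relies on: a successful load runs once and is cached, a failed load raises an ImportError on every call (functools.cache does not cache exceptions), and call sites pick only the symbols they need by tuple unpacking.

import functools

@functools.cache
def _load_example_kernel():
    try:
        import example_kernel_pkg  # placeholder for lazy_load_kernel("...")
    except ImportError as exc:
        raise ImportError("example kernel is unavailable; install or update `kernels`.") from exc
    matmul = getattr(example_kernel_pkg, "matmul", None)          # placeholder symbol
    act_quant = getattr(example_kernel_pkg, "act_quant", None)    # placeholder symbol
    missing = [name for name, fn in (("matmul", matmul), ("act_quant", act_quant)) if fn is None]
    if missing:
        raise ImportError(f"example kernel is missing required symbols: {', '.join(missing)}")
    return matmul, act_quant

# Call sites unpack what they need, mirroring `_, fp8_act_quant, _, _ = _load_triton_kernel()`:
# matmul, _ = _load_example_kernel()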
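
A toy walk-through of the alignment arithmetic in _build_deepgemm_contiguous_layout, using alignment=4 instead of 128 and made-up per-expert token counts; only torch is required and it runs on CPU. It shows how the zero-prepended cumsum of per-expert padding maps sorted token positions into aligned expert blocks and pushes EP sentinels past every valid block.

import torch
import torch.nn.functional as F

alignment, num_experts = 4, 3
# Already sorted by expert; value == num_experts (3) is an EP sentinel.
expert_ids_sorted = torch.tensor([0, 0, 1, 1, 1, 1, 1, 2, 3, 3])
num_tokens = expert_ids_sorted.numel()

# Sentinels (== num_experts) fall outside [0, num_experts - 1], so histc drops them.
tokens_per_expert = torch.histc(
    expert_ids_sorted.float(), bins=num_experts, min=0, max=num_experts - 1
).long()                                                                      # [2, 5, 1]
aligned = ((tokens_per_expert + alignment - 1) // alignment) * alignment      # [4, 8, 4]
padding = aligned - tokens_per_expert                                         # [2, 3, 3]

# Zero-prepended cumsum: index e = padding inserted before expert e's rows,
# index num_experts = total padding, which routes sentinels past every aligned block.
cumulative_padding = F.pad(padding.cumsum(0), (1, 0))                         # [0, 2, 5, 8]
sorted_to_padded = torch.arange(num_tokens) + cumulative_padding[expert_ids_sorted]

print(sorted_to_padded.tolist())  # [0, 1, 4, 5, 6, 7, 8, 12, 16, 17]
# Expert 0's aligned block is padded rows 0-3 (its 2 tokens at 0-1), expert 1's is 4-11,
# expert 2's is 12-15; the two sentinel rows land at 16-17, at or past aligned.sum() == 16,
# so the grouped GEMM never touches them.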
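
A CPU-runnable sketch of the EP-sentinel handling in the grouped MoE paths: sentinels sort to the tail, histc drops them from the offsets, and their rows (left uninitialized by the kernel) are zeroed after the routing-weight multiply so no NaN reaches the reduction. The tensor shapes, the toy router, and the NaN fill standing in for uninitialized memory are assumptions made for the example.

import torch

num_experts, num_tokens, num_top_k, hidden_dim = 4, 6, 2, 8
hidden_states = torch.randn(num_tokens, hidden_dim)

# Toy routing: value == num_experts marks token/expert pairs owned by another EP rank.
top_k_index = torch.randint(0, num_experts + 1, (num_tokens, num_top_k))
top_k_weights = torch.rand(num_tokens, num_top_k)
top_k_weights[top_k_index == num_experts] = 0.0  # router already zeroes sentinel weights

expert_ids = top_k_index.reshape(-1)        # (S,)
sample_weights = top_k_weights.reshape(-1)  # (S,)

# Sort by expert: unclamped sentinels (== num_experts) land at the tail.
expert_ids_g, perm = torch.sort(expert_ids)
selected = hidden_states[perm // num_top_k]

# histc with max=num_experts - 1 drops sentinels, so offsets[-1] excludes their rows.
tokens_per_expert = torch.histc(expert_ids_g.float(), bins=num_experts, min=0, max=num_experts - 1)
offsets = torch.cumsum(tokens_per_expert, dim=0, dtype=torch.int32)

# A grouped GEMM leaves rows past offsets[-1] untouched; emulate that with NaN garbage.
proj_out = torch.full_like(selected, float("nan"))
num_valid = int(offsets[-1])
proj_out[:num_valid] = selected[:num_valid]  # stand-in for the real per-expert computation

weighted = proj_out * sample_weights[perm].unsqueeze(-1)            # sentinel rows: 0 * NaN = NaN
weighted.masked_fill_((expert_ids_g >= num_experts).unsqueeze(-1), 0.0)  # zero them explicitly

inv_perm = torch.empty_like(perm)
inv_perm[perm] = torch.arange(perm.numel())
out = weighted[inv_perm].view(num_tokens, num_top_k, hidden_dim).sum(dim=1)
assert torch.isfinite(out).all()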