From ce5ee8662b98afcd0c251d5daee52b2beb26afa5 Mon Sep 17 00:00:00 2001
From: Pawel Gadzinski
Date: Tue, 17 Feb 2026 19:45:49 +0000
Subject: [PATCH 1/3] init

Signed-off-by: Pawel Gadzinski
---
 tests/pytorch/test_permutation.py         |   63 +-
 transformer_engine/pytorch/permutation.py | 1125 ++++++++++-------
 .../pytorch/quantized_tensor.py           |    6 +
 3 files changed, 752 insertions(+), 442 deletions(-)

diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py
index be1ff30472..80616d097e 100644
--- a/tests/pytorch/test_permutation.py
+++ b/tests/pytorch/test_permutation.py
@@ -227,6 +227,7 @@ def _test_permutation_index_map(
     num_out_tokens,
     with_probs,
     BENCHMARK=False,
+    use_torch_compile=False,
 ):
     if not with_probs and topK > 1:
         pytest.skip("Only permutations with topK=1 and without probabilities are supported.")
@@ -298,9 +299,27 @@ def _test_permutation_index_map(
     te_permute_fwd_input.requires_grad_(True)
     te_permute_bwd_input = pytorch_permute_bwd_input.detach()

-    te_permute_output, row_id_map = te_permute(
-        te_permute_fwd_input, indices, num_out_tokens, map_type="index"
-    )
+    if use_torch_compile:
+        # Reset dynamo to avoid recompile limit across parametrized tests
+        torch._dynamo.reset()
+        # Disable donated buffers to allow retain_graph=True
+        import torch._functorch.config as functorch_config
+        old_donated_buffer = functorch_config.donated_buffer
+        functorch_config.donated_buffer = False
+
+        # Create a wrapper function for torch.compile
+        def permute_wrapper(inp, idx, num_out, max_token):
+            return te_permute(inp, idx, num_out, max_token, map_type="index")
+
+        # Compile with fullgraph=True
+        compiled_permute = torch.compile(permute_wrapper, fullgraph=True)
+        te_permute_output, row_id_map = compiled_permute(
+            te_permute_fwd_input, indices, num_out_tokens, -1
+        )
+    else:
+        te_permute_output, row_id_map = te_permute(
+            te_permute_fwd_input, indices, num_out_tokens, map_type="index"
+        )
     te_permute_output.backward(te_permute_bwd_input, retain_graph=True)

     te_probs = None
@@ -311,11 +330,25 @@ def _test_permutation_index_map(
     te_unpermute_fwd_input.requires_grad_(True)
     te_unpermute_bwd_input = pytorch_unpermute_bwd_input.detach()

-    te_unpermute_output = te_unpermute(
-        te_unpermute_fwd_input, row_id_map, te_probs, map_type="index"
-    )
+    if use_torch_compile:
+        # Create a wrapper function for torch.compile
+        def unpermute_wrapper(inp, row_map, probs_val):
+            return te_unpermute(inp, row_map, probs_val, map_type="index")
+
+        # Compile with fullgraph=True
+        compiled_unpermute = torch.compile(unpermute_wrapper, fullgraph=True)
+        te_unpermute_output = compiled_unpermute(
+            te_unpermute_fwd_input, row_id_map, te_probs
+        )
+    else:
+        te_unpermute_output = te_unpermute(
+            te_unpermute_fwd_input, row_id_map, te_probs, map_type="index"
+        )
     te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True)

+    if use_torch_compile:
+        functorch_config.donated_buffer = old_donated_buffer
+
     ###################################################################################################################################
     #
     # Results Check
@@ -1647,6 +1680,11 @@ def perf_test_cuda_kernel(cuda_kernel_fn):
 @pytest.mark.parametrize("hidden_size", [4096])
 @pytest.mark.parametrize("topK", [2, 5])
 @pytest.mark.parametrize("num_out_tokens", [None, 2039])
+@pytest.mark.parametrize("use_torch_compile", [False, True])
+@pytest.mark.skipif(
+    torch.__version__ < "2",
+    reason="torch.compile not available - skipping torch.compile tests",
+)
 def test_permutation_index_map(
     te_dtype,
     num_tokens,
@@ -1654,7 +1692,10 @@ def test_permutation_index_map(
     hidden_size,
     topK,
     num_out_tokens,
+    use_torch_compile,
 ):
+    if use_torch_compile and torch.__version__ < "2":
+        pytest.skip("torch.compile not available")
     with_probs = True
     BENCHMARK = False

@@ -1667,6 +1708,7 @@ def test_permutation_index_map(
         num_out_tokens=num_out_tokens,
         with_probs=with_probs,
         BENCHMARK=BENCHMARK,
+        use_torch_compile=use_torch_compile,
     )


@@ -1875,12 +1917,20 @@ def test_permutation_mask_map_fp8(
 @pytest.mark.parametrize("num_tokens", [4096])
 @pytest.mark.parametrize("num_expert", [7, 16])
 @pytest.mark.parametrize("hidden_size", [4096])
+@pytest.mark.parametrize("use_torch_compile", [False, True])
+@pytest.mark.skipif(
+    torch.__version__ < "2",
+    reason="torch.compile not available - skipping torch.compile tests",
+)
 def test_permutation_index_map_topk1_no_probs(
     te_dtype,
     num_tokens,
     num_expert,
     hidden_size,
+    use_torch_compile,
 ):
+    if use_torch_compile and torch.__version__ < "2":
+        pytest.skip("torch.compile not available")
     topK = 1
     num_out_tokens = None
     with_probs = False
@@ -1895,6 +1945,7 @@ def test_permutation_index_map_topk1_no_probs(
         num_out_tokens=num_out_tokens,
         with_probs=with_probs,
         BENCHMARK=BENCHMARK,
+        use_torch_compile=use_torch_compile,
     )


diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py
index 5beeed1262..ff2ad693c6 100644
--- a/transformer_engine/pytorch/permutation.py
+++ b/transformer_engine/pytorch/permutation.py
@@ -22,495 +22,701 @@
 ]


-class _moe_permute_index_map(torch.autograd.Function):
-    """functional Permute with index router map"""
+# Workspace state for moe_permute_index_map (module-level for compatibility)
+_moe_permute_index_map_workspace = None
+_moe_permute_index_map_max_expanded_token_num = 0

-    workspace = None
-    max_expanded_token_num = 0

-    @staticmethod
-    def forward(
-        ctx,
-        inp: torch.Tensor,
-        index: torch.Tensor,
-        num_out_tokens: int,
-        max_token_num: int,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # pylint: disable=missing-function-docstring
-        # Empty input check
-        if not inp.numel():
-            return inp, torch.tensor([], device=inp.device)
+
+@torch.library.custom_op("te_moe::permute_index_map", mutates_args=[])
+def moe_permute_index_map_forward(
+    inp: torch.Tensor,
+    index: torch.Tensor,
+    num_out_tokens: int,
+    max_token_num: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Forward pass for MoE permute with index router map."""
+    global _moe_permute_index_map_workspace, _moe_permute_index_map_max_expanded_token_num
+
+    # Empty input check
+    if not inp.numel():
+        return inp, torch.tensor([], device=inp.device)
+
+    # Device check
+    assert inp.is_cuda, "TransformerEngine needs CUDA."
+    assert index.is_cuda, "TransformerEngine needs CUDA."
+    # Shape check
+    assert inp.size(0) == index.size(0), "Permute not possible"
+
+    # Data type check
+    dtype = TE_DType[inp.dtype]
+    if index.dtype != torch.int32:
+        warnings.warn(
+            f"The data type of the input `index` of Permute is {index.dtype}! "
+            "The recommended type is torch.int32."
+        )
+        index = index.to(torch.int32)
+
+    topK = index.size(1)
+
+    input_max_expanded_token_num = max(max_token_num, inp.size(0)) * topK
+    if _moe_permute_index_map_max_expanded_token_num < input_max_expanded_token_num:
+        _moe_permute_index_map_max_expanded_token_num = input_max_expanded_token_num
+        _moe_permute_index_map_workspace = []
+
+    permuted_act, row_id_map, _moe_permute_index_map_workspace = tex.moe_permute_fwd(
+        inp,
+        dtype,
+        index,
+        num_out_tokens,
+        _moe_permute_index_map_workspace,
+        _moe_permute_index_map_max_expanded_token_num,
+    )

-        # Device check
-        assert inp.is_cuda, "TransformerEngine needs CUDA."
-        assert index.is_cuda, "TransformerEngine needs CUDA."
-        # Shape check
-        assert inp.size(0) == index.size(0), "Permute not possible"
+    return permuted_act, row_id_map

-        # Data type check
-        dtype = TE_DType[inp.dtype]
-        if index.dtype != torch.int32:
-            warnings.warn(
-                f"The data type of the input `index` of Permute is {index.dtype}! "
-                "The recommended type is torch.int32."
-            )
-            index = index.to(torch.int32)

-        topK = index.size(1)
+@moe_permute_index_map_forward.register_fake
+def _moe_permute_index_map_fake(
+    inp: torch.Tensor,
+    index: torch.Tensor,
+    num_out_tokens: int,
+    max_token_num: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Fake implementation for shape inference."""
+    if not inp.numel():
+        return inp, torch.tensor([], device=inp.device)

-        input_max_expanded_token_num = max(max_token_num, inp.size(0)) * topK
-        if _moe_permute_index_map.max_expanded_token_num < input_max_expanded_token_num:
-            _moe_permute_index_map.max_expanded_token_num = input_max_expanded_token_num
-            _moe_permute_index_map.workspace = []
+    num_tokens = inp.shape[0]
+    topK = index.shape[1] if index.numel() > 0 else 1

-        permuted_act, row_id_map, _moe_permute_index_map.workspace = tex.moe_permute_fwd(
-            inp,
-            dtype,
-            index,
-            num_out_tokens,
-            _moe_permute_index_map.workspace,
-            _moe_permute_index_map.max_expanded_token_num,
-        )
+    # Infer output shape (see permutation.cpp line 55)
+    output_tokens = num_out_tokens if num_out_tokens > 0 else num_tokens * topK

-        ctx.row_id_map = row_id_map
-        ctx.num_tokens = index.size(0)
-        ctx.topK = index.size(1)
-        return permuted_act, row_id_map
+    # row_id_map is 1D with size = num_tokens * topK (see permutation.cpp line 59-60)
+    fake_output = torch.empty(
+        (output_tokens, inp.shape[1]), dtype=inp.dtype, device=inp.device
+    )
+    fake_row_id_map = torch.empty(
+        (num_tokens * topK,), dtype=torch.int32, device=inp.device
+    )

-    @staticmethod
-    def backward(
-        ctx,
-        permuted_act_grad: torch.Tensor,
-        _,
-    ) -> Tuple[torch.Tensor, ...]:
-        # pylint: disable=missing-function-docstring
-        # Empty input check
-        if not permuted_act_grad.numel():
-            return permuted_act_grad, None, None, None
+    return fake_output, fake_row_id_map

-        if not permuted_act_grad.is_contiguous():
-            permuted_act_grad = permuted_act_grad.contiguous()

-        dtype = TE_DType[permuted_act_grad.dtype]
-        act_grad = None
-        if ctx.needs_input_grad[0]:
-            act_grad = tex.moe_permute_bwd(
-                permuted_act_grad, dtype, ctx.row_id_map, torch.empty(0), ctx.num_tokens, ctx.topK
-            )
+@torch.library.custom_op("te_moe::permute_index_map_bwd", mutates_args=[])
+def moe_permute_index_map_backward(
+    grad_permuted_act: torch.Tensor,
+    row_id_map: torch.Tensor,
+    num_tokens: int,
+    topK: int,
+) -> torch.Tensor:
+    """Backward pass for MoE permute with index router map."""
+    if not grad_permuted_act.is_contiguous():
+        grad_permuted_act = grad_permuted_act.contiguous()

-        return act_grad, None, None, None
+    dtype = TE_DType[grad_permuted_act.dtype]
+    act_grad = tex.moe_permute_bwd(
+        grad_permuted_act, dtype, row_id_map, torch.empty(0), num_tokens, topK
+    )
+    return act_grad


-class _moe_unpermute_index_map(torch.autograd.Function):
-    """functional Unpermute with index router map"""
+@moe_permute_index_map_backward.register_fake
+def _moe_permute_index_map_backward_fake(
+    grad_permuted_act: torch.Tensor,
+    row_id_map: torch.Tensor,
+    num_tokens: int,
+    topK: int,
+) -> torch.Tensor:
+    """Fake implementation for shape inference of backward."""
+    return torch.empty(
+        (num_tokens, grad_permuted_act.shape[1]),
+        dtype=grad_permuted_act.dtype,
+        device=grad_permuted_act.device,
+    )

-    @staticmethod
-    def forward(
-        ctx,
-        inp: torch.Tensor,
-        row_id_map: torch.Tensor,
-        probs: torch.Tensor,
-    ) -> torch.Tensor:
-        # pylint: disable=missing-function-docstring
-        # Empty input check
-        if not inp.numel():
-            ctx.probs = probs
-            return inp

-        # None probs check
-        if probs is not None:
-            assert probs.is_cuda, "TransformerEngine needs CUDA."
+def _moe_permute_index_map_setup_context(ctx, inputs, output):
+    """Save context for backward pass."""
+    inp, index, num_out_tokens, max_token_num = inputs
+    permuted_act, row_id_map = output
+    ctx.save_for_backward(row_id_map)
+    ctx.num_tokens = index.size(0) if index.numel() > 0 else 0
+    ctx.topK = index.size(1) if index.numel() > 0 else 1

-            if probs.dtype != torch.float32:
-                warnings.warn(
-                    f"The data type of the input `probs` of Unpermute is {probs.dtype}! "
-                    "The recommended type is torch.float32."
-                )
-                probs = probs.to(torch.float32)
-            num_tokens = probs.size(0)
-            topK = probs.size(1)
-        else:
-            num_tokens = row_id_map.size(0)
-            topK = 1
-            probs = torch.empty(0)

+def _moe_permute_index_map_backward_wrapper(ctx, grad_permuted_act, grad_row_id_map):
+    """Backward pass wrapper that calls the custom backward op."""
+    # Empty input check
+    if not grad_permuted_act.numel():
+        return grad_permuted_act, None, None, None

-        # Device check
-        assert inp.is_cuda, "TransformerEngine needs CUDA."
-        assert row_id_map.is_cuda, "TransformerEngine needs CUDA."
+    (row_id_map,) = ctx.saved_tensors
+    act_grad = moe_permute_index_map_backward(
+        grad_permuted_act, row_id_map, ctx.num_tokens, ctx.topK
+    )

-        # Data type check
-        dtype = TE_DType[inp.dtype]
-        if row_id_map.dtype != torch.int32:
-            warnings.warn(
-                f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! "
-                "The recommended type is torch.int32."
-            )
-            row_id_map = row_id_map.to(torch.int32)
+    return act_grad, None, None, None

-        unpermuted_output = tex.moe_unpermute_fwd(inp, dtype, row_id_map, probs, num_tokens, topK)

-        ctx.save_for_backward(inp, row_id_map, probs)
-        return unpermuted_output
+moe_permute_index_map_forward.register_autograd(
+    _moe_permute_index_map_backward_wrapper,
+    setup_context=_moe_permute_index_map_setup_context,
+)

-    @staticmethod
-    def backward(
-        ctx,
-        unpermuted_act_grad: torch.Tensor,
-    ) -> Tuple[torch.Tensor, None, torch.Tensor]:
-        # pylint: disable=missing-function-docstring
-        # Empty input check
-        if not unpermuted_act_grad.numel():
-            return unpermuted_act_grad, None, ctx.probs

-        if not unpermuted_act_grad.is_contiguous():
-            unpermuted_act_grad = unpermuted_act_grad.contiguous()
+# ---------------------------------- Forward custom op ----------------------------------

-        dtype = TE_DType[unpermuted_act_grad.dtype]
-        inp, row_id_map, probs = ctx.saved_tensors
+@torch.library.custom_op("te_moe::unpermute_index_map_fwd", mutates_args=[])
+def moe_unpermute_index_map_forward(
+    inp: torch.Tensor,
+    row_id_map: torch.Tensor,
+    probs: torch.Tensor,
+    num_tokens: int,
+    topK: int,
+) -> torch.Tensor:
+    """Forward pass for MoE unpermute with index router map."""
+    dtype = TE_DType[inp.dtype]
+    return tex.moe_unpermute_fwd(inp, dtype, row_id_map, probs, num_tokens, topK)

-        act_grad = None
-        prob_grad = None
-        if ctx.needs_input_grad[0]:
-            act_grad, prob_grad = tex.moe_unpermute_bwd(
-                unpermuted_act_grad, inp, dtype, row_id_map, probs
-            )

-        if not ctx.needs_input_grad[2]:
-            prob_grad = None
-        return act_grad, None, prob_grad
+@moe_unpermute_index_map_forward.register_fake
+def _moe_unpermute_index_map_forward_fake(
+    inp: torch.Tensor,
+    row_id_map: torch.Tensor,
+    probs: torch.Tensor,
+    num_tokens: int,
+    topK: int,
+) -> torch.Tensor:
+    """Fake implementation for shape inference."""
+    # Output shape: (num_tokens, hidden_size) — see permutation.cpp line 95-97
+    return torch.empty(
+        (num_tokens, inp.shape[1]), dtype=inp.dtype, device=inp.device
+    )


-class _moe_permute_mask_map(torch.autograd.Function):
-    """functional Permute with mask router map"""
+# ---------------------------------- Backward custom op ----------------------------------

-    @staticmethod
-    def forward(
-        ctx,
-        inp: torch.Tensor,
-        routing_map: torch.Tensor,
-        num_out_tokens: int,
-        probs: torch.Tensor,
-        pad_offsets: Optional[torch.Tensor],
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # pylint: disable=missing-function-docstring
-        if not inp.numel():
-            ctx.probs = probs
-            return inp, torch.tensor([], device=inp.device), torch.tensor([], device=inp.device)
+@torch.library.custom_op("te_moe::unpermute_index_map_bwd", mutates_args=[])
+def moe_unpermute_index_map_backward(
+    unpermuted_act_grad: torch.Tensor,
+    fwd_input: torch.Tensor,
+    row_id_map: torch.Tensor,
+    probs: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Backward pass for MoE unpermute with index router map."""
+    dtype = TE_DType[unpermuted_act_grad.dtype]
+    act_grad, prob_grad = tex.moe_unpermute_bwd(
+        unpermuted_act_grad, fwd_input, dtype, row_id_map, probs
+    )
+    return act_grad, prob_grad

-        assert inp.is_cuda, "TransformerEngine needs CUDA."
-        assert routing_map.is_cuda, "TransformerEngine needs CUDA."
-        if probs is not None:
-            assert probs.is_cuda, "TransformerEngine needs CUDA."
-        if pad_offsets is not None:
-            assert pad_offsets.is_cuda, "TransformerEngine needs CUDA."
-        assert inp.size(0) == routing_map.size(0), "Permute not possible"
-        num_tokens, hidden_size = inp.size()
-        num_experts = routing_map.size(1)
-        assert (
-            num_out_tokens is not None
-        ), "num_out_tokens must be provided to the fused permute function."
+@moe_unpermute_index_map_backward.register_fake
+def _moe_unpermute_index_map_backward_fake(
+    unpermuted_act_grad: torch.Tensor,
+    fwd_input: torch.Tensor,
+    row_id_map: torch.Tensor,
+    probs: torch.Tensor,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Fake implementation for shape inference of backward."""
+    # act_grad shape: (fwd_input.size(0), hidden_size) — see permutation.cpp line 127-129
+    # prob_grad shape: (num_tokens, topK) — see permutation.cpp line 130-131
+    topK = probs.size(1) if probs.numel() > 0 else 1
+    num_tokens = probs.size(0) if probs.numel() > 0 else row_id_map.size(0)
+    act_grad = torch.empty(
+        (fwd_input.size(0), unpermuted_act_grad.shape[1]),
+        dtype=unpermuted_act_grad.dtype,
+        device=unpermuted_act_grad.device,
+    )
+    prob_grad = torch.empty(
+        (num_tokens, topK), dtype=torch.float32, device=unpermuted_act_grad.device
+    )
+    return act_grad, prob_grad

-        row_id_map = triton_permutation.make_row_id_map(routing_map, num_tokens, num_experts)

-        fp8 = isinstance(inp, QuantizedTensor)
-        per_tensor_recipe = isinstance(inp, Float8Tensor)
-        blockwise_recipe = isinstance(inp, Float8BlockwiseQTensor)
-        mxfp8_recipe = isinstance(inp, MXFP8Tensor)
+# ---------------------------------- Autograd glue ----------------------------------

-        if fp8:
-            fp8_dtype = inp._fp8_dtype
-            fake_dtype = inp.dtype
-            # blockwise scaling
-            if blockwise_recipe:
-                fp8_scale = inp._rowwise_scale_inv.T.contiguous()
-                scale_hidden_dim = fp8_scale.shape[1]
-                assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
-                inp = inp._rowwise_data
-            # mxfp8 scaling
-            elif mxfp8_recipe:
-                fp8_scale = inp._rowwise_scale_inv.contiguous()
-                scale_hidden_dim = fp8_scale.shape[1]
-                assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
-                inp = inp._rowwise_data
-            # per-tensor scaling
-            elif per_tensor_recipe:
-                # Kernel does not need scale in per-tensor scaling
-                fp8_scale = None
-                scale_hidden_dim = None
-                fp8_scale_inv = inp._scale_inv
-                inp = inp._data
-            else:
-                raise ValueError("Unsupported FP8 recipe")
-        else:
+def _moe_unpermute_index_map_setup_context(ctx, inputs, output):
+    """Save context for backward pass."""
+    inp, row_id_map, probs, num_tokens, topK = inputs
+    ctx.save_for_backward(inp, row_id_map, probs)
+    ctx.needs_probs_grad = probs.requires_grad if probs.numel() > 0 else False
+
+
+def _moe_unpermute_index_map_backward_wrapper(ctx, unpermuted_act_grad):
+    """Backward pass wrapper that calls the custom backward op."""
+    if not unpermuted_act_grad.numel():
+        return unpermuted_act_grad, None, None, None, None
+
+    if not unpermuted_act_grad.is_contiguous():
+        unpermuted_act_grad = unpermuted_act_grad.contiguous()
+
+    inp, row_id_map, probs = ctx.saved_tensors
+
+    act_grad, prob_grad = moe_unpermute_index_map_backward(
+        unpermuted_act_grad, inp, row_id_map, probs
+    )
+
+    if not ctx.needs_probs_grad:
+        prob_grad = None
+
+    return act_grad, None, prob_grad, None, None
+
+
+moe_unpermute_index_map_forward.register_autograd(
+    _moe_unpermute_index_map_backward_wrapper,
+    setup_context=_moe_unpermute_index_map_setup_context,
+)
+
+
+# ===================== _moe_permute_mask_map custom ops =====================
+
+@torch.library.custom_op("te_moe::permute_mask_map_fwd", mutates_args=[])
+def moe_permute_mask_map_forward(
+    inp: torch.Tensor,
+    routing_map: torch.Tensor,
+    num_out_tokens: int,
+    probs: Optional[torch.Tensor],
+    pad_offsets: Optional[torch.Tensor],
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Forward pass for MoE permute with mask router map."""
+    # Empty input check
+    if not inp.numel():
+        return (
+            inp,
+            torch.tensor([], device=inp.device),
+            torch.tensor([], device=inp.device),
+        )
+
+    assert inp.is_cuda, "TransformerEngine needs CUDA."
+    assert routing_map.is_cuda, "TransformerEngine needs CUDA."
+    if probs is not None:
+        assert probs.is_cuda, "TransformerEngine needs CUDA."
+    if pad_offsets is not None:
+        assert pad_offsets.is_cuda, "TransformerEngine needs CUDA."
+    assert inp.size(0) == routing_map.size(0), "Permute not possible"
+    assert num_out_tokens is not None, "num_out_tokens must be provided to the fused permute function."
+
+    num_tokens, hidden_size = inp.size()
+    num_experts = routing_map.size(1)
+
+    row_id_map = triton_permutation.make_row_id_map(routing_map, num_tokens, num_experts)
+
+    # FP8 handling
+    fp8 = isinstance(inp, QuantizedTensor)
+    per_tensor_recipe = isinstance(inp, Float8Tensor)
+    blockwise_recipe = isinstance(inp, Float8BlockwiseQTensor)
+    mxfp8_recipe = isinstance(inp, MXFP8Tensor)
+
+    if fp8:
+        fp8_dtype = inp._fp8_dtype
+        fake_dtype = inp.dtype
+        if blockwise_recipe:
+            fp8_scale = inp._rowwise_scale_inv.T.contiguous()
+            scale_hidden_dim = fp8_scale.shape[1]
+            assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
+            inp = inp._rowwise_data
+        elif mxfp8_recipe:
+            fp8_scale = inp._rowwise_scale_inv.contiguous()
+            scale_hidden_dim = fp8_scale.shape[1]
+            assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
+            inp = inp._rowwise_data
+        elif per_tensor_recipe:
             fp8_scale = None
-            fp8_dtype = None
             scale_hidden_dim = None
+            fp8_scale_inv = inp._scale_inv
+            inp = inp._data
+        else:
+            raise ValueError("Unsupported FP8 recipe")
+    else:
+        fp8_scale = None
+        fp8_dtype = None
+        scale_hidden_dim = None

-        output, permuted_scale, permuted_probs = triton_permutation.permute_with_mask_map(
-            inp,
-            row_id_map,
-            probs,
-            fp8_scale,
-            pad_offsets,
-            num_tokens,
-            num_experts,
-            num_out_tokens,
-            hidden_size,
-            scale_hidden_dim,
+    output, permuted_scale, permuted_probs = triton_permutation.permute_with_mask_map(
+        inp, row_id_map, probs, fp8_scale, pad_offsets,
+        num_tokens, num_experts, num_out_tokens, hidden_size, scale_hidden_dim,
+    )
+
+    if fp8:
+        if per_tensor_recipe:
+            output = Float8Tensor(
+                data=output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv,
+                shape=output.shape, dtype=fake_dtype,
+            )
+        elif blockwise_recipe:
+            output = Float8BlockwiseQTensor(
+                shape=output.shape, dtype=fake_dtype, rowwise_data=output,
+                rowwise_scale_inv=permuted_scale.T.contiguous(),
+                columnwise_data=None, columnwise_scale_inv=None,
+                fp8_dtype=fp8_dtype, quantizer=None, is_2D_scaled=False,
+                requires_grad=output.requires_grad,
+            )
+        elif mxfp8_recipe:
+            output = MXFP8Tensor(
+                shape=output.shape, dtype=fake_dtype, fp8_dtype=fp8_dtype,
+                rowwise_data=output, rowwise_scale_inv=permuted_scale.contiguous(),
+                columnwise_data=None, columnwise_scale_inv=None,
+                quantizer=None, requires_grad=output.requires_grad,
+                with_gemm_swizzled_scales=False,
+            )
+
+    # If permuted_probs is None, return empty tensor (custom ops need concrete tensors)
+    if permuted_probs is None:
+        permuted_probs = torch.empty(0, device=inp.device)
+
+    return output, row_id_map, permuted_probs
+
+
+@moe_permute_mask_map_forward.register_fake
+def _moe_permute_mask_map_forward_fake(
+    inp: torch.Tensor,
+    routing_map: torch.Tensor,
+    num_out_tokens: int,
+    probs: Optional[torch.Tensor],
+    pad_offsets: Optional[torch.Tensor],
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Fake implementation for shape inference."""
+    num_tokens = inp.shape[0]
+    hidden_size = inp.shape[1]
+    num_experts = routing_map.shape[1]
+    # row_id_map: (num_tokens, num_experts * 2 + 1) — see triton make_row_id_map
+    fake_output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device=inp.device)
+    fake_row_id_map = torch.empty(
+        (num_tokens, num_experts * 2 + 1), dtype=torch.int32, device=inp.device
+    )
+    if probs is not None:
+        fake_permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device=inp.device)
+    else:
+        fake_permuted_probs = torch.empty(0, device=inp.device)
+    return fake_output, fake_row_id_map, fake_permuted_probs
+
+
+@torch.library.custom_op("te_moe::permute_mask_map_bwd", mutates_args=[])
+def moe_permute_mask_map_backward(
+    permuted_act_grad: torch.Tensor,
+    permuted_probs_grad: Optional[torch.Tensor],
+    row_id_map: torch.Tensor,
+    pad_offsets: Optional[torch.Tensor],
+    num_tokens: int,
+    num_experts: int,
+    hidden_size: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Backward pass for MoE permute with mask router map."""
+    act_grad, probs_grad = triton_permutation.unpermute_with_mask_map(
+        permuted_act_grad, row_id_map, None, permuted_probs_grad, pad_offsets,
+        num_tokens, num_experts, hidden_size,
+    )
+    if probs_grad is None:
+        probs_grad = torch.empty(0, device=permuted_act_grad.device)
+    return act_grad, probs_grad
+
+
+@moe_permute_mask_map_backward.register_fake
+def _moe_permute_mask_map_backward_fake(
+    permuted_act_grad: torch.Tensor,
+    permuted_probs_grad: Optional[torch.Tensor],
+    row_id_map: torch.Tensor,
+    pad_offsets: Optional[torch.Tensor],
+    num_tokens: int,
+    num_experts: int,
+    hidden_size: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Fake for backward shape inference."""
+    act_grad = torch.empty(
+        (num_tokens, hidden_size), dtype=permuted_act_grad.dtype, device=permuted_act_grad.device
+    )
+    if permuted_probs_grad is not None:
+        probs_grad = torch.empty(
+            (num_tokens, num_experts), dtype=permuted_probs_grad.dtype,
+            device=permuted_act_grad.device,
         )
+    else:
+        probs_grad = torch.empty(0, device=permuted_act_grad.device)
+    return act_grad, probs_grad

-        if fp8:
-            if per_tensor_recipe:
-                output = Float8Tensor(
-                    data=output,
-                    fp8_dtype=fp8_dtype,
-                    fp8_scale_inv=fp8_scale_inv,
-                    shape=output.shape,
-                    dtype=fake_dtype,
-                )
-            elif blockwise_recipe:
-                output = Float8BlockwiseQTensor(
-                    shape=output.shape,
-                    dtype=fake_dtype,
-                    rowwise_data=output,
-                    rowwise_scale_inv=permuted_scale.T.contiguous(),
-                    columnwise_data=None,
-                    columnwise_scale_inv=None,
-                    fp8_dtype=fp8_dtype,
-                    quantizer=None,
-                    is_2D_scaled=False,
-                    requires_grad=output.requires_grad,
-                )
-            elif mxfp8_recipe:
-                output = MXFP8Tensor(
-                    shape=output.shape,
-                    dtype=fake_dtype,
-                    fp8_dtype=fp8_dtype,
-                    rowwise_data=output,
-                    rowwise_scale_inv=permuted_scale.contiguous(),
-                    columnwise_data=None,
-                    columnwise_scale_inv=None,
-                    quantizer=None,
-                    requires_grad=output.requires_grad,
-                    with_gemm_swizzled_scales=False,
-                )
+def _moe_permute_mask_map_setup_context(ctx, inputs, output):
+    """Save context for backward pass."""
+    inp, routing_map, num_out_tokens, probs, pad_offsets = inputs
+    output_tensor, row_id_map, permuted_probs = output
+    ctx.empty_input = not inp.numel()
+    if ctx.empty_input and probs is not None:
+        ctx.save_for_backward(row_id_map, pad_offsets, probs)
+    else:
         ctx.save_for_backward(row_id_map, pad_offsets)
-        ctx.num_experts = num_experts
-        ctx.num_tokens = num_tokens
-        ctx.hidden_size = hidden_size
-        return output, row_id_map, permuted_probs
+    ctx.num_experts = routing_map.size(1) if routing_map.numel() > 0 else 0
+    ctx.num_tokens = inp.size(0)
+    ctx.hidden_size = inp.size(1) if inp.numel() > 0 else 0
+    ctx.needs_probs_grad = probs is not None and probs.requires_grad

-    @staticmethod
-    def backward(
-        ctx,
-        permuted_act_grad: torch.Tensor,
-        _,
-        permuted_probs_grad: torch.Tensor,
-    ) -> Tuple[torch.Tensor, ...]:
-        # pylint: disable=missing-function-docstring
-        if not permuted_act_grad.numel():
-            return permuted_act_grad, None, None, ctx.probs, None

-        act_grad = None
+def _moe_permute_mask_map_backward_wrapper(ctx, grad_output, grad_row_id_map, grad_permuted_probs):
+    """Backward wrapper calling the custom backward op."""
+    if ctx.empty_input:
+        if ctx.needs_probs_grad:
+            _, _, probs = ctx.saved_tensors
+            return grad_output, None, None, probs, None
+        return grad_output, None, None, None, None
+
+    assert not isinstance(
+        grad_output, QuantizedTensor
+    ), "The backward of moe_permute does not support FP8."
+
+    row_id_map, pad_offsets = ctx.saved_tensors
+
+    # Pass permuted_probs_grad only if it has content
+    probs_grad_input = grad_permuted_probs if grad_permuted_probs.numel() > 0 else None
+
+    act_grad, probs_grad = moe_permute_mask_map_backward(
+        grad_output, probs_grad_input, row_id_map, pad_offsets,
+        ctx.num_tokens, ctx.num_experts, ctx.hidden_size,
+    )
+
+    if not ctx.needs_probs_grad or probs_grad.numel() == 0:
         probs_grad = None
-        if ctx.needs_input_grad[0]:
-            row_id_map, pad_offsets = ctx.saved_tensors
-            assert not isinstance(
-                permuted_act_grad, QuantizedTensor
-            ), "The backward of moe_permute does not support FP8."
-            act_grad, probs_grad = triton_permutation.unpermute_with_mask_map(
-                permuted_act_grad,
-                row_id_map,
-                None,
-                permuted_probs_grad,
-                pad_offsets,
-                ctx.num_tokens,
-                ctx.num_experts,
-                ctx.hidden_size,
-            )
-        if not ctx.needs_input_grad[3]:
-            probs_grad = None
-        return act_grad, None, None, probs_grad, None
+    return act_grad, None, None, probs_grad, None


-class _moe_unpermute_mask_map(torch.autograd.Function):
-    """functional Unpermute with mask router map"""
+moe_permute_mask_map_forward.register_autograd(
+    _moe_permute_mask_map_backward_wrapper,
+    setup_context=_moe_permute_mask_map_setup_context,
+)

-    @staticmethod
-    def forward(
-        ctx,
-        inp: torch.Tensor,
-        row_id_map: torch.Tensor,
-        merging_probs: Optional[torch.Tensor],
-        restore_shape: Optional[torch.Size],
-        pad_offsets: Optional[torch.Tensor],
-    ) -> torch.Tensor:
-        # pylint: disable=missing-function-docstring
-        if not inp.numel():
-            ctx.merging_probs = merging_probs
-            return inp

-        if restore_shape is None:
-            restore_shape = inp.shape
-        num_tokens, hidden_size = restore_shape
-        num_experts = (row_id_map.size(1) - 1) // 2
-        with_probs = merging_probs is not None
-        if with_probs:
-            assert merging_probs.is_cuda, "TransformerEngine needs CUDA."
+# ===================== _moe_unpermute_mask_map custom ops =====================

-        # Device check
-        assert inp.is_cuda, "TransformerEngine needs CUDA."
-        assert row_id_map.is_cuda, "TransformerEngine needs CUDA."
-        if pad_offsets is not None:
-            assert pad_offsets.is_cuda, "TransformerEngine needs CUDA."
+@torch.library.custom_op("te_moe::unpermute_mask_map_fwd", mutates_args=[])
+def moe_unpermute_mask_map_forward(
+    inp: torch.Tensor,
+    row_id_map: torch.Tensor,
+    merging_probs: Optional[torch.Tensor],
+    num_tokens: int,
+    num_experts: int,
+    hidden_size: int,
+    pad_offsets: Optional[torch.Tensor],
+) -> torch.Tensor:
+    """Forward pass for MoE unpermute with mask router map."""
+    # Empty input check
+    if not inp.numel():
+        return inp
+
+    assert not isinstance(
+        inp, QuantizedTensor
+    ), "The forward of moe_unpermute does not support FP8."
+    unpermuted_output, _ = triton_permutation.unpermute_with_mask_map(
+        inp, row_id_map, merging_probs, None, pad_offsets,
+        num_tokens, num_experts, hidden_size,
+    )
+    return unpermuted_output

-        assert not isinstance(
-            inp, QuantizedTensor
-        ), "The forward of moe_unpermute does not support FP8."
-        unpermuted_output, _ = triton_permutation.unpermute_with_mask_map(
-            inp,
-            row_id_map,
-            merging_probs,
-            None,
-            pad_offsets,
-            num_tokens,
-            num_experts,
-            hidden_size,
-        )

-        if with_probs:
-            ctx.save_for_backward(inp, row_id_map, merging_probs, pad_offsets)
+@moe_unpermute_mask_map_forward.register_fake
+def _moe_unpermute_mask_map_forward_fake(
+    inp: torch.Tensor,
+    row_id_map: torch.Tensor,
+    merging_probs: Optional[torch.Tensor],
+    num_tokens: int,
+    num_experts: int,
+    hidden_size: int,
+    pad_offsets: Optional[torch.Tensor],
+) -> torch.Tensor:
+    """Fake implementation for shape inference."""
+    return torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device=inp.device)
+
+
+@torch.library.custom_op("te_moe::unpermute_mask_map_bwd_with_probs", mutates_args=[])
+def moe_unpermute_mask_map_backward_with_probs(
+    unpermuted_act_grad: torch.Tensor,
+    row_id_map: torch.Tensor,
+    fwd_input: torch.Tensor,
+    merging_probs: torch.Tensor,
+    pad_offsets: Optional[torch.Tensor],
+    num_tokens: int,
+    num_experts: int,
+    num_permuted_tokens: int,
+    hidden_size: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Backward pass for MoE unpermute with merging probs."""
+    act_grad, probs_grad = triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs(
+        unpermuted_act_grad, row_id_map, fwd_input, merging_probs, pad_offsets,
+        num_tokens, num_experts, num_permuted_tokens, hidden_size,
+    )
+    return act_grad, probs_grad
+
+
+@moe_unpermute_mask_map_backward_with_probs.register_fake
+def _moe_unpermute_mask_map_bwd_with_probs_fake(
+    unpermuted_act_grad: torch.Tensor,
+    row_id_map: torch.Tensor,
+    fwd_input: torch.Tensor,
+    merging_probs: torch.Tensor,
+    pad_offsets: Optional[torch.Tensor],
+    num_tokens: int,
+    num_experts: int,
+    num_permuted_tokens: int,
+    hidden_size: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Fake for backward shape inference with merging probs."""
+    act_grad = torch.empty(
+        (num_permuted_tokens, hidden_size),
+        dtype=unpermuted_act_grad.dtype, device=unpermuted_act_grad.device,
+    )
+    probs_grad = torch.empty(
+        (num_tokens, num_experts),
+        dtype=merging_probs.dtype, device=unpermuted_act_grad.device,
+    )
+    return act_grad, probs_grad
+
+
+@torch.library.custom_op("te_moe::unpermute_mask_map_bwd_no_probs", mutates_args=[])
+def moe_unpermute_mask_map_backward_no_probs(
+    unpermuted_act_grad: torch.Tensor,
+    row_id_map: torch.Tensor,
+    pad_offsets: Optional[torch.Tensor],
+    num_tokens: int,
+    num_experts: int,
+    num_permuted_tokens: int,
+    hidden_size: int,
+) -> torch.Tensor:
+    """Backward pass for MoE unpermute without merging probs (permute grad back)."""
+    # FP8 handling
+    fp8 = isinstance(unpermuted_act_grad, QuantizedTensor)
+    per_tensor_recipe = isinstance(unpermuted_act_grad, Float8Tensor)
+    blockwise_recipe = isinstance(unpermuted_act_grad, Float8BlockwiseQTensor)
+    mxfp8_recipe = isinstance(unpermuted_act_grad, MXFP8Tensor)
+
+    if fp8:
+        fp8_dtype = unpermuted_act_grad._fp8_dtype
+        fake_dtype = unpermuted_act_grad.dtype
+        if per_tensor_recipe:
+            fp8_scale = None
+            scale_hidden_dim = None
+            fp8_scale_inv = unpermuted_act_grad._scale_inv
+            unpermuted_act_grad = unpermuted_act_grad._data
+        elif blockwise_recipe:
+            fp8_scale = unpermuted_act_grad._rowwise_scale_inv.T.contiguous()
+            unpermuted_act_grad = unpermuted_act_grad._rowwise_data
+            scale_hidden_dim = fp8_scale.shape[1]
+            assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
+        elif mxfp8_recipe:
+            fp8_scale = unpermuted_act_grad._rowwise_scale_inv.contiguous()
+            unpermuted_act_grad = unpermuted_act_grad._rowwise_data
+            scale_hidden_dim = fp8_scale.shape[1]
+            assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
         else:
-            ctx.save_for_backward(row_id_map, pad_offsets)
-        ctx.num_experts = num_experts
-        ctx.num_tokens = num_tokens
-        ctx.num_permuted_tokens = inp.size(0)
-        ctx.hidden_size = hidden_size
-        ctx.with_probs = with_probs
-        return unpermuted_output
+            raise ValueError("Unsupported FP8 recipe")
+    else:
+        scale_hidden_dim = None
+        fp8_dtype = None
+        fp8_scale = None

-    @staticmethod
-    def backward(ctx, unpermuted_act_grad):
-        # pylint: disable=missing-function-docstring
-        if not unpermuted_act_grad.numel():
-            return unpermuted_act_grad, None, ctx.merging_probs, None, None
+    act_grad, permuted_scale, _ = triton_permutation.permute_with_mask_map(
+        unpermuted_act_grad, row_id_map, None, fp8_scale, pad_offsets,
+        num_tokens, num_experts, num_permuted_tokens, hidden_size, scale_hidden_dim,
+    )

-        act_grad = None
+    if fp8:
+        if per_tensor_recipe:
+            act_grad = Float8Tensor(
+                data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv,
+                shape=act_grad.shape, dtype=fake_dtype,
+            )
+        elif blockwise_recipe:
+            act_grad = Float8BlockwiseQTensor(
+                shape=act_grad.shape, dtype=fake_dtype, rowwise_data=act_grad,
+                rowwise_scale_inv=permuted_scale.T.contiguous(),
+                columnwise_data=None, columnwise_scale_inv=None,
+                fp8_dtype=fp8_dtype, quantizer=None, is_2D_scaled=False,
+                requires_grad=act_grad.requires_grad,
+            )
+        elif mxfp8_recipe:
+            act_grad = MXFP8Tensor(
+                shape=act_grad.shape, dtype=fake_dtype, fp8_dtype=fp8_dtype,
+                rowwise_data=act_grad, rowwise_scale_inv=permuted_scale.contiguous(),
+                columnwise_data=None, columnwise_scale_inv=None,
+                quantizer=None, requires_grad=act_grad.requires_grad,
+                with_gemm_swizzled_scales=False,
+            )
+
+    return act_grad
+
+
+@moe_unpermute_mask_map_backward_no_probs.register_fake
+def _moe_unpermute_mask_map_bwd_no_probs_fake(
+    unpermuted_act_grad: torch.Tensor,
+    row_id_map: torch.Tensor,
+    pad_offsets: Optional[torch.Tensor],
+    num_tokens: int,
+    num_experts: int,
+    num_permuted_tokens: int,
+    hidden_size: int,
+) -> torch.Tensor:
+    """Fake for backward shape inference without probs."""
+    return torch.empty(
+        (num_permuted_tokens, hidden_size),
+        dtype=unpermuted_act_grad.dtype, device=unpermuted_act_grad.device,
+    )
+
+
+def _moe_unpermute_mask_map_setup_context(ctx, inputs, output):
+    """Save context for backward pass."""
+    inp, row_id_map, merging_probs, num_tokens, num_experts, hidden_size, pad_offsets = inputs
+    ctx.num_experts = num_experts
+    ctx.num_tokens = num_tokens
+    ctx.num_permuted_tokens = inp.size(0)
+    ctx.hidden_size = hidden_size
+    ctx.with_probs = merging_probs is not None
+    ctx.empty_input = not inp.numel()
+    if ctx.with_probs:
+        ctx.save_for_backward(inp, row_id_map, merging_probs, pad_offsets)
+        ctx.needs_probs_grad = merging_probs.requires_grad
+    else:
+        ctx.save_for_backward(row_id_map, pad_offsets)
+        ctx.needs_probs_grad = False
+
+
+def _moe_unpermute_mask_map_backward_wrapper(ctx, unpermuted_act_grad):
+    """Backward wrapper calling the appropriate custom backward op."""
+    if ctx.empty_input:
+        # Return merging_probs as its own grad for empty input (matches original behavior)
+        if ctx.with_probs:
+            _, _, merging_probs, _ = ctx.saved_tensors
+            return unpermuted_act_grad, None, merging_probs, None, None, None, None
+        return unpermuted_act_grad, None, None, None, None, None, None
+
+    act_grad = None
+    probs_grad = None
+
+    if ctx.with_probs:
+        fwd_input, row_id_map, merging_probs, pad_offsets = ctx.saved_tensors
+        assert not isinstance(
+            unpermuted_act_grad, QuantizedTensor
+        ), "The backward of moe_unpermute with merging probs does not support FP8."
+        act_grad, probs_grad = moe_unpermute_mask_map_backward_with_probs(
+            unpermuted_act_grad, row_id_map, fwd_input, merging_probs, pad_offsets,
+            ctx.num_tokens, ctx.num_experts, ctx.num_permuted_tokens, ctx.hidden_size,
+        )
+    else:
+        row_id_map, pad_offsets = ctx.saved_tensors
+        act_grad = moe_unpermute_mask_map_backward_no_probs(
+            unpermuted_act_grad, row_id_map, pad_offsets,
+            ctx.num_tokens, ctx.num_experts, ctx.num_permuted_tokens, ctx.hidden_size,
+        )
+
+    if not ctx.needs_probs_grad:
         probs_grad = None
-        if ctx.needs_input_grad[0]:
-            if ctx.with_probs:
-                fwd_input, row_id_map, merging_probs, pad_offsets = ctx.saved_tensors
-            else:
-                row_id_map, pad_offsets = ctx.saved_tensors

-            fp8 = isinstance(unpermuted_act_grad, QuantizedTensor)
-            per_tensor_recipe = isinstance(unpermuted_act_grad, Float8Tensor)
-            blockwise_recipe = isinstance(unpermuted_act_grad, Float8BlockwiseQTensor)
-            mxfp8_recipe = isinstance(unpermuted_act_grad, MXFP8Tensor)
+    return act_grad, None, probs_grad, None, None, None, None

-            if fp8:
-                fp8_dtype = unpermuted_act_grad._fp8_dtype
-                fake_dtype = unpermuted_act_grad.dtype
-                # per-tensor scaling
-                if per_tensor_recipe:
-                    # Kernel does not need scale in per-tensor scaling
-                    fp8_scale = None
-                    scale_hidden_dim = None
-                    fp8_scale_inv = unpermuted_act_grad._scale_inv
-                    unpermuted_act_grad = unpermuted_act_grad._data
-                # blockwise scaling
-                elif blockwise_recipe:
-                    fp8_scale = unpermuted_act_grad._rowwise_scale_inv.T.contiguous()
-                    unpermuted_act_grad = unpermuted_act_grad._rowwise_data
-                    scale_hidden_dim = fp8_scale.shape[1]
-                    assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
-                # mxfp8 scaling
-                elif mxfp8_recipe:
-                    fp8_scale = unpermuted_act_grad._rowwise_scale_inv.contiguous()
-                    unpermuted_act_grad = unpermuted_act_grad._rowwise_data
-                    scale_hidden_dim = fp8_scale.shape[1]
-                    assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch"
-                else:
-                    raise ValueError("Unsupported FP8 recipe")
-            else:
-                scale_hidden_dim = None
-                fp8_dtype = None
-                fp8_scale = None
-
-            if ctx.with_probs:
-                assert (
-                    not fp8
-                ), "The backward of moe_unpermute with merging probs does not support FP8."
-                act_grad, probs_grad = (
-                    triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs(
-                        unpermuted_act_grad,
-                        row_id_map,
-                        fwd_input,
-                        merging_probs,
-                        pad_offsets,
-                        ctx.num_tokens,
-                        ctx.num_experts,
-                        ctx.num_permuted_tokens,
-                        ctx.hidden_size,
-                    )
-                )
-            else:
-                act_grad, permuted_scale, _ = triton_permutation.permute_with_mask_map(
-                    unpermuted_act_grad,
-                    row_id_map,
-                    None,
-                    fp8_scale,
-                    pad_offsets,
-                    ctx.num_tokens,
-                    ctx.num_experts,
-                    ctx.num_permuted_tokens,
-                    ctx.hidden_size,
-                    scale_hidden_dim,
-                )
-                if fp8:
-                    if per_tensor_recipe:
-                        act_grad = Float8Tensor(
-                            data=act_grad,
-                            fp8_dtype=fp8_dtype,
-                            fp8_scale_inv=fp8_scale_inv,
-                            shape=act_grad.shape,
-                            dtype=fake_dtype,
-                        )
-                    elif blockwise_recipe:
-                        act_grad = Float8BlockwiseQTensor(
-                            shape=act_grad.shape,
-                            dtype=fake_dtype,
-                            rowwise_data=act_grad,
-                            rowwise_scale_inv=permuted_scale.T.contiguous(),
-                            columnwise_data=None,
-                            columnwise_scale_inv=None,
-                            fp8_dtype=fp8_dtype,
-                            quantizer=None,
-                            is_2D_scaled=False,
-                            requires_grad=act_grad.requires_grad,
-                        )
-                    elif mxfp8_recipe:
-                        act_grad = MXFP8Tensor(
-                            shape=act_grad.shape,
-                            dtype=fake_dtype,
-                            fp8_dtype=fp8_dtype,
-                            rowwise_data=act_grad,
-                            rowwise_scale_inv=permuted_scale.contiguous(),
-                            columnwise_data=None,
-                            columnwise_scale_inv=None,
-                            quantizer=None,
-                            requires_grad=act_grad.requires_grad,
-                            with_gemm_swizzled_scales=False,
-                        )
-
-        if not ctx.needs_input_grad[2]:
-            probs_grad = None
-        return act_grad, None, probs_grad, None, None
+moe_unpermute_mask_map_forward.register_autograd(
+    _moe_unpermute_mask_map_backward_wrapper,
+    setup_context=_moe_unpermute_mask_map_setup_context,
+)


 def moe_permute(
@@ -548,9 +754,9 @@ def moe_permute(
         Refer to `routing_map` for more details.
     """
     if map_type == "index":
-        return _moe_permute_index_map.apply(inp, routing_map, num_out_tokens, max_token_num)
+        return moe_permute_index_map_forward(inp, routing_map, num_out_tokens, max_token_num)
     if map_type == "mask":
-        output, row_id_map, _ = _moe_permute_mask_map.apply(
+        output, row_id_map, _ = moe_permute_mask_map_forward(
             inp, routing_map, num_out_tokens, None, None
         )
         return output, row_id_map
@@ -584,7 +790,7 @@ def moe_permute_with_probs(
         The effective output token count, representing the number of tokens not dropped.
         By default, set to '-1', meaning no tokens are dropped.
     """
-    output, row_id_map, permuted_probs = _moe_permute_mask_map.apply(
+    output, row_id_map, permuted_probs = moe_permute_mask_map_forward(
        inp, routing_map, num_out_tokens, probs, None
     )
     return output, permuted_probs, row_id_map
@@ -640,7 +846,7 @@ def moe_permute_and_pad_with_probs(
         [torch.zeros(1, dtype=cum_pad.dtype, device=inp.device), cum_pad[:-1]]
     )

-    output, row_id_map, permuted_probs = _moe_permute_mask_map.apply(
+    output, row_id_map, permuted_probs = moe_permute_mask_map_forward(
         inp, routing_map, target_tokens_per_expert.sum().item(), probs, pad_offsets
     )
     return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert
@@ -690,10 +896,57 @@ def moe_unpermute(
         warnings.warn("probs kwarg is deprecated. Use merging_probs kwarg instead.")
         merging_probs = probs
     if map_type == "index":
-        return _moe_unpermute_index_map.apply(inp, row_id_map, merging_probs)
+        # Empty input check
+        if not inp.numel():
+            return inp
+
+        # Normalize probs
+        if merging_probs is not None:
+            assert merging_probs.is_cuda, "TransformerEngine needs CUDA."
+            if merging_probs.dtype != torch.float32:
+                warnings.warn(
+                    f"The data type of the input `probs` of Unpermute is {merging_probs.dtype}! "
+                    "The recommended type is torch.float32."
+                )
+                merging_probs = merging_probs.to(torch.float32)
+            num_tokens = merging_probs.size(0)
+            topK = merging_probs.size(1)
+        else:
+            num_tokens = row_id_map.size(0)
+            topK = 1
+            merging_probs = torch.empty(0, device=inp.device)
+
+        # Device check
+        assert inp.is_cuda, "TransformerEngine needs CUDA."
+        assert row_id_map.is_cuda, "TransformerEngine needs CUDA."
+        if row_id_map.dtype != torch.int32:
+            warnings.warn(
+                f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! "
+                "The recommended type is torch.int32."
+            )
+            row_id_map = row_id_map.to(torch.int32)
+
+        return moe_unpermute_index_map_forward(inp, row_id_map, merging_probs, num_tokens, topK)
     if map_type == "mask":
-        return _moe_unpermute_mask_map.apply(
-            inp, row_id_map, merging_probs, restore_shape, pad_offsets
+        if restore_shape is None:
+            restore_shape = inp.shape
+        num_tokens, hidden_size = restore_shape
+        num_experts = (row_id_map.size(1) - 1) // 2 if row_id_map.numel() > 0 else 0
+
+        if not inp.numel():
+            # Pass through custom op even for empty input so probs stays in the graph
+            pass
+        else:
+            if merging_probs is not None:
+                assert merging_probs.is_cuda, "TransformerEngine needs CUDA."
+            assert inp.is_cuda, "TransformerEngine needs CUDA."
+            assert row_id_map.is_cuda, "TransformerEngine needs CUDA."
+            if pad_offsets is not None:
+                assert pad_offsets.is_cuda, "TransformerEngine needs CUDA."
+
+        return moe_unpermute_mask_map_forward(
+            inp, row_id_map, merging_probs,
+            num_tokens, num_experts, hidden_size, pad_offsets,
         )
     raise ValueError("map_type should be one of 'mask' or 'index'")

diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py
index d78677bc83..678b884812 100644
--- a/transformer_engine/pytorch/quantized_tensor.py
+++ b/transformer_engine/pytorch/quantized_tensor.py
@@ -516,6 +516,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
                 return func(t)
             return False  # Or error out?
+        # Pass through te_moe custom ops without unwrapping
+        if hasattr(func, "namespace") and func.namespace == "te_moe":
+            if kwargs is None:
+                kwargs = {}
+            return super().__torch_dispatch__(func, types, args, kwargs)
+
         def maybe_unwrap(arg):
             if isinstance(arg, QuantizedTensor):
                 return arg.dequantize(dtype=arg.dtype)

From 8159d263185e157415c56ef6592821737533bb3a Mon Sep 17 00:00:00 2001
From: Pawel Gadzinski
Date: Wed, 18 Feb 2026 17:31:17 +0000
Subject: [PATCH 2/3] work finished

Signed-off-by: Pawel Gadzinski
---
 tests/pytorch/test_permutation.py         | 173 ++++---
 transformer_engine/pytorch/permutation.py | 432 ++++++++++--------
 .../pytorch/quantized_tensor.py           |  10 +-
 3 files changed, 340 insertions(+), 275 deletions(-)

diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py
index 80616d097e..bc5957a18b 100644
--- a/tests/pytorch/test_permutation.py
+++ b/tests/pytorch/test_permutation.py
@@ -218,6 +218,16 @@ def backward_wrapper(
     return act.backward(backward_input, retain_graph=retain_graph)


+def _maybe_compile(fn, use_torch_compile):
+    """Wrap fn with torch.compile(fullgraph=True) if requested."""
+    if use_torch_compile:
+        torch._dynamo.reset()
+        import torch._functorch.config as functorch_config
+        functorch_config.donated_buffer = False
+        return torch.compile(fn, fullgraph=True)
+    return fn
+
+
 def _test_permutation_index_map(
     te_dtype,
     num_tokens,
@@ -299,27 +309,13 @@ def _test_permutation_index_map(
     te_permute_fwd_input.requires_grad_(True)
     te_permute_bwd_input = pytorch_permute_bwd_input.detach()

-    if use_torch_compile:
-        # Reset dynamo to avoid recompile limit across parametrized tests
-        torch._dynamo.reset()
-        # Disable donated buffers to allow retain_graph=True
-        import torch._functorch.config as functorch_config
-        old_donated_buffer = functorch_config.donated_buffer
-        functorch_config.donated_buffer = False
-
-        # Create a wrapper function for torch.compile
-        def permute_wrapper(inp, idx, num_out, max_token):
-            return te_permute(inp, idx, num_out, max_token, map_type="index")
-
-        # Compile with fullgraph=True
-        compiled_permute = torch.compile(permute_wrapper, fullgraph=True)
-        te_permute_output, row_id_map = compiled_permute(
-            te_permute_fwd_input, indices, num_out_tokens, -1
-        )
-    else:
-        te_permute_output, row_id_map = te_permute(
-            te_permute_fwd_input, indices, num_out_tokens, map_type="index"
-        )
+    _permute = _maybe_compile(
+        lambda inp, idx, num_out, max_token: te_permute(inp, idx, num_out, max_token, map_type="index"),
+        use_torch_compile,
+    )
+    te_permute_output, row_id_map = _permute(
+        te_permute_fwd_input, indices, num_out_tokens, -1
+    )
     te_permute_output.backward(te_permute_bwd_input, retain_graph=True)

     te_probs = None
@@ -330,25 +326,15 @@ def permute_wrapper(inp, idx, num_out, max_token):
     te_unpermute_fwd_input.requires_grad_(True)
     te_unpermute_bwd_input = pytorch_unpermute_bwd_input.detach()

-    if use_torch_compile:
-        # Create a wrapper function for torch.compile
-        def unpermute_wrapper(inp, row_map, probs_val):
-            return te_unpermute(inp, row_map, probs_val, map_type="index")
-
-        # Compile with fullgraph=True
-        compiled_unpermute = torch.compile(unpermute_wrapper, fullgraph=True)
-        te_unpermute_output = compiled_unpermute(
-            te_unpermute_fwd_input, row_id_map, te_probs
-        )
-    else:
-        te_unpermute_output = te_unpermute(
-            te_unpermute_fwd_input, row_id_map, te_probs, map_type="index"
-        )
+    _unpermute = _maybe_compile(
+        lambda inp, row_map, probs_val: te_unpermute(inp, row_map, probs_val, map_type="index"),
+        use_torch_compile,
+    )
+    te_unpermute_output = _unpermute(
+        te_unpermute_fwd_input, row_id_map, te_probs
+    )
     te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True)

-    if use_torch_compile:
-        functorch_config.donated_buffer = old_donated_buffer
-
     ###################################################################################################################################
     #
     # Results Check
@@ -477,6 +463,7 @@ def _test_permutation_mask_map(
     num_out_tokens,
     with_probs,
     BENCHMARK=False,
+    use_torch_compile=False,
 ):
     if topK > num_expert:
         pytest.skip("topK should be smaller than the number of experts.")
@@ -547,8 +534,12 @@ def _test_permutation_mask_map(
     te_permute_fwd_input.requires_grad_(True)
     te_permute_bwd_input = pytorch_permute_bwd_input.detach()

-    te_permute_output, row_id_map = te_permute(
-        te_permute_fwd_input, routing_map, num_out_tokens=num_out_tokens, map_type="mask"
+    _permute = _maybe_compile(
+        lambda inp, rmap, n_out: te_permute(inp, rmap, num_out_tokens=n_out, map_type="mask"),
+        use_torch_compile,
+    )
+    te_permute_output, row_id_map = _permute(
+        te_permute_fwd_input, routing_map, num_out_tokens
     )
     te_permute_output.backward(te_permute_bwd_input, retain_graph=True)

@@ -560,8 +551,12 @@ def _test_permutation_mask_map(
     te_unpermute_fwd_input.requires_grad_(True)
     te_unpermute_bwd_input = pytorch_unpermute_bwd_input.detach()

-    te_unpermute_output = te_unpermute(
-        te_unpermute_fwd_input, row_id_map, te_probs, restore_shape, map_type="mask"
+    _unpermute = _maybe_compile(
+        lambda inp, row_map, p, rs: te_unpermute(inp, row_map, p, rs, map_type="mask"),
+        use_torch_compile,
+    )
+    te_unpermute_output = _unpermute(
+        te_unpermute_fwd_input, row_id_map, te_probs, restore_shape
     )
     te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True)

@@ -699,6 +694,7 @@ def _test_permutation_and_padding_mask_map(
     with_merging_probs=False,
     align_size=16,
     BENCHMARK=False,
+    use_torch_compile=False,
 ):
     if topK > num_expert:
         pytest.skip("topK should be smaller than the number of experts.")
@@ -990,6 +986,7 @@ def _test_permutation_and_padding_with_merging_probs(
     num_out_tokens,
     align_size=16,
     BENCHMARK=False,
+    use_torch_compile=False,
 ):
     """
     Test the combination of merging_probs AND pad_offsets together in moe_unpermute.
@@ -1324,6 +1321,7 @@ def _test_moe_chunk_sort(
     tp_size,
     hidden_size,
     BENCHMARK=False,
+    use_torch_compile=False,
 ):
     print(
         "chunk permute:"
@@ -1373,7 +1371,11 @@ def _test_moe_chunk_sort(
     te_fwd_input.requires_grad_(True)
     te_bwd_input = pytorch_bwd_input.detach()

-    te_output = te_sort_chunks_by_index(te_fwd_input, split_sizes_cuda, sorted_idxs_cuda)
+    _sort = _maybe_compile(
+        lambda inp, ss, si: te_sort_chunks_by_index(inp, ss, si),
+        use_torch_compile,
+    )
+    te_output = _sort(te_fwd_input, split_sizes_cuda, sorted_idxs_cuda)
     te_output.backward(te_bwd_input, retain_graph=True)

     ###################################################################################################################################
@@ -1448,6 +1450,7 @@ def _test_permutation_mask_map_alongside_probs(
     num_out_tokens,
     tp_size,
     BENCHMARK=False,
+    use_torch_compile=False,
 ):
     if topK > num_expert:
         pytest.skip("topK should be smaller than the number of experts.")
@@ -1543,30 +1546,20 @@ def _test_permutation_mask_map_alongside_probs(
     te_probs = probs.detach()
     te_probs.requires_grad_(True)

-    te_permute_output, te_permuted_probs, row_id_map = te_permute_with_probs(
-        te_permute_fwd_input,
-        te_probs,
-        routing_map,
-        num_out_tokens=num_out_tokens,
-    )
-
-    te_permute_output, te_permuted_probs = te_sort_chunks_by_index_with_probs(
-        te_permute_output, te_permuted_probs, split_sizes_cuda, sorted_idxs_cuda
-    )
-
-    te_permute_output_dtype = te_permute_output.dtype
-    te_permute_output = te_permute_output * te_permuted_probs.unsqueeze(-1)
-    te_permute_output = te_permute_output.to(dtype=te_permute_output_dtype)
-
-    te_permute_output = te_sort_chunks_by_index(
-        te_permute_output, split_sizes_2_cuda, sorted_idxs_2_cuda
-    )
-
-    te_unpermute_output = te_unpermute(
-        te_permute_output,
-        row_id_map,
-        restore_shape=restore_shape,
-        map_type="mask",
+    def _alongside_probs_fn(fwd_inp, t_probs, rmap, ss1, si1, ss2, si2):
+        out, pprobs, rid = te_permute_with_probs(fwd_inp, t_probs, rmap, num_out_tokens=num_out_tokens)
+        out, pprobs = te_sort_chunks_by_index_with_probs(out, pprobs, ss1, si1)
+        out_dtype = out.dtype
+        out = out * pprobs.unsqueeze(-1)
+        out = out.to(dtype=out_dtype)
+        out = te_sort_chunks_by_index(out, ss2, si2)
+        out = te_unpermute(out, rid, restore_shape=restore_shape, map_type="mask")
+        return out
+
+    _fn = _maybe_compile(_alongside_probs_fn, use_torch_compile)
+    te_unpermute_output = _fn(
+        te_permute_fwd_input, te_probs, routing_map,
+        split_sizes_cuda, sorted_idxs_cuda, split_sizes_2_cuda, sorted_idxs_2_cuda,
     )
     te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True)

@@ -1696,6 +1689,8 @@ def test_permutation_index_map(
 ):
     if use_torch_compile and torch.__version__ < "2":
         pytest.skip("torch.compile not available")
+    if use_torch_compile and (num_expert != 7 or topK != 2):
+        pytest.skip("torch.compile tested with single config only")
     with_probs = True
     BENCHMARK = False

@@ -1718,6 +1713,7 @@ def test_permutation_index_map(
 @pytest.mark.parametrize("hidden_size", [4096])
 @pytest.mark.parametrize("topK", [2, 5])
 @pytest.mark.parametrize("num_out_tokens", [None, 2039])
+@pytest.mark.parametrize("use_torch_compile", [False, True])
 def test_permutation_mask_map(
     te_dtype,
     num_tokens,
@@ -1725,7 +1721,10 @@ def test_permutation_mask_map(
     hidden_size,
     topK,
     num_out_tokens,
+    use_torch_compile,
 ):
+    if use_torch_compile and (num_expert != 7 or topK != 2):
+        pytest.skip("torch.compile tested with single config only")
     with_probs = True
     BENCHMARK = False

@@ -1738,6 +1737,7 @@ def test_permutation_mask_map(
         num_out_tokens=num_out_tokens,
         with_probs=with_probs,
         BENCHMARK=BENCHMARK,
+        use_torch_compile=use_torch_compile,
     )


@@ -1753,6 +1753,7 @@ def test_permutation_mask_map(
     ],
 )
 @pytest.mark.parametrize("with_merging_probs", [True, False])
+@pytest.mark.parametrize("use_torch_compile", [False, True])
 def test_permutation_and_padding_mask_map(
     te_dtype,
     num_tokens,
@@ -1761,7 +1762,10 @@ def test_permutation_and_padding_mask_map(
     topK,
     num_out_tokens,
     with_merging_probs,
+    use_torch_compile,
 ):
+    if use_torch_compile and (num_expert != 8 or topK != 2):
+        pytest.skip("torch.compile tested with single config only")
     BENCHMARK = False

     _test_permutation_and_padding_mask_map(
@@ -1773,6 +1777,7 @@ def test_permutation_and_padding_mask_map(
         num_out_tokens=num_out_tokens,
         with_merging_probs=with_merging_probs,
         BENCHMARK=BENCHMARK,
+        use_torch_compile=use_torch_compile,
     )


@@ -1787,6 +1792,7 @@ def test_permutation_and_padding_mask_map(
         (4096, 512, 9216, 8),
     ],
 )
+@pytest.mark.parametrize("use_torch_compile", [False, True])
 def test_permutation_and_padding_with_merging_probs(
     te_dtype,
     num_tokens,
@@ -1794,8 +1800,11 @@ def test_permutation_and_padding_with_merging_probs(
     hidden_size,
     topK,
     num_out_tokens,
+    use_torch_compile,
 ):
     """Test moe_unpermute backward pass with BOTH merging_probs AND pad_offsets."""
+    if use_torch_compile and (num_expert != 8 or topK != 2):
+        pytest.skip("torch.compile tested with single config only")
     BENCHMARK = False

     _test_permutation_and_padding_with_merging_probs(
@@ -1806,11 +1815,13 @@ def test_permutation_and_padding_with_merging_probs(
         topK=topK,
         num_out_tokens=num_out_tokens,
         BENCHMARK=BENCHMARK,
+        use_torch_compile=use_torch_compile,
     )


 @pytest.mark.parametrize("te_dtype", _te_dtypes)
-def test_permutation_mask_map_empty_input(te_dtype):
+@pytest.mark.parametrize("use_torch_compile", [False, True])
+def test_permutation_mask_map_empty_input(te_dtype, use_torch_compile):
     with_probs = True
     BENCHMARK = False

@@ -1823,6 +1834,7 @@ def test_permutation_mask_map_empty_input(te_dtype):
         num_out_tokens=0,
         with_probs=with_probs,
         BENCHMARK=BENCHMARK,
+        use_torch_compile=use_torch_compile,
     )


@@ -1833,6 +1845,7 @@ def test_permutation_mask_map_empty_input(te_dtype):
 @pytest.mark.parametrize("topK", [2, 5])
 @pytest.mark.parametrize("num_out_tokens", [None, 2039])
 @pytest.mark.parametrize("tp_size", [1, 2])
+@pytest.mark.parametrize("use_torch_compile", [False, True])
 def test_permutation_mask_map_alongside_probs(
     te_dtype,
     num_tokens,
@@ -1841,7 +1854,10 @@ def test_permutation_mask_map_alongside_probs(
     topK,
     num_out_tokens,
     tp_size,
+    use_torch_compile,
 ):
+    if use_torch_compile and (num_expert != 7 or topK != 2 or tp_size != 1):
+        pytest.skip("torch.compile tested with single config only")
     _test_permutation_mask_map_alongside_probs(
         te_dtype=te_dtype,
         num_tokens=num_tokens,
@@ -1850,11 +1866,13 @@ def test_permutation_mask_map_alongside_probs(
         topK=topK,
         num_out_tokens=num_out_tokens,
         tp_size=tp_size,
+        use_torch_compile=use_torch_compile,
     )


 @pytest.mark.parametrize("te_dtype", _te_dtypes)
-def test_permutation_mask_map_alongside_probs_empty_input(te_dtype):
+@pytest.mark.parametrize("use_torch_compile", [False, True])
+def test_permutation_mask_map_alongside_probs_empty_input(te_dtype, use_torch_compile):
     _test_permutation_mask_map_alongside_probs(
         te_dtype=te_dtype,
         num_tokens=0,
@@ -1863,6 +1881,7 @@ def test_permutation_mask_map_alongside_probs_empty_input(te_dtype):
         topK=2,
         num_out_tokens=0,
         tp_size=2,
+        use_torch_compile=use_torch_compile,
     )


@@ -1931,6 +1950,8 @@ def test_permutation_index_map_topk1_no_probs(
 ):
     if use_torch_compile and torch.__version__ < "2":
         pytest.skip("torch.compile not available")
+    if use_torch_compile and num_expert != 7:
+        pytest.skip("torch.compile tested with single config only")
     topK = 1
     num_out_tokens = None
     with_probs = False
@@ -1953,12 +1974,16 @@ def test_permutation_index_map_topk1_no_probs(
 @pytest.mark.parametrize("num_tokens", [4096])
 @pytest.mark.parametrize("num_expert", [7, 16])
 @pytest.mark.parametrize("hidden_size", [4096])
+@pytest.mark.parametrize("use_torch_compile", [False, True])
 def test_permutation_mask_map_topk1_no_probs(
     te_dtype,
     num_tokens,
     num_expert,
     hidden_size,
+    use_torch_compile,
 ):
+    if use_torch_compile and num_expert != 7:
+        pytest.skip("torch.compile tested with single config only")
     topK = 1
     num_out_tokens = None
     with_probs = False
@@ -1973,6 +1998,7 @@ def test_permutation_mask_map_topk1_no_probs(
         num_out_tokens=num_out_tokens,
         with_probs=with_probs,
         BENCHMARK=BENCHMARK,
+        use_torch_compile=use_torch_compile,
     )


@@ -1981,13 +2007,17 @@ def test_permutation_mask_map_topk1_no_probs(
 @pytest.mark.parametrize("num_expert", [7, 16])
 @pytest.mark.parametrize("tp_size", [2, 8])
 @pytest.mark.parametrize("hidden_size", [4096])
+@pytest.mark.parametrize("use_torch_compile", [False, True])
 def test_chunk_permutation(
     te_dtype,
     num_tokens,
     num_expert,
     tp_size,
     hidden_size,
+    use_torch_compile,
 ):
+    if use_torch_compile and (num_expert != 7 or tp_size != 2):
+        pytest.skip("torch.compile tested with single config only")
     BENCHMARK = False

     _test_moe_chunk_sort(
@@ -1997,11 +2027,13 @@ def test_chunk_permutation(
         tp_size=tp_size,
         hidden_size=hidden_size,
         BENCHMARK=BENCHMARK,
+        use_torch_compile=use_torch_compile,
     )


 @pytest.mark.parametrize("te_dtype", _te_dtypes)
-def test_chunk_permutation_empty_input(te_dtype):
+@pytest.mark.parametrize("use_torch_compile", [False, True])
+def test_chunk_permutation_empty_input(te_dtype, use_torch_compile):
     BENCHMARK = False

     _test_moe_chunk_sort(
@@ -2011,6 +2043,7 @@ def test_chunk_permutation_empty_input(te_dtype):
         tp_size=2,
         hidden_size=4096,
         BENCHMARK=BENCHMARK,
+        use_torch_compile=use_torch_compile,
     )

diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py
index ff2ad693c6..7d2cb8f6f9 100644
--- a/transformer_engine/pytorch/permutation.py
+++ b/transformer_engine/pytorch/permutation.py
@@ -7,10 +7,16 @@
 from typing import Optional, Tuple

 import torch
+# Allow warnings.warn inside torch.compile without graph breaks
+torch._dynamo.config.reorderable_logging_functions.add(warnings.warn)
+
 import transformer_engine_torch as tex
 import transformer_engine.pytorch.triton.permutation as triton_permutation
 from transformer_engine.pytorch.constants import TE_DType
-from transformer_engine.pytorch.quantized_tensor import QuantizedTensor
+from transformer_engine.pytorch.quantized_tensor import (
+    QuantizedTensor,
+    _quantized_tensor_passthrough_ops,
+)
 from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor
 from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor
 from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor
@@ -22,7 +28,9 @@
 ]


-# Workspace state for moe_permute_index_map (module-level for compatibility)
+# ===================== _moe_permute_index_map custom ops =====================
+
+# Workspace state for moe_permute_index_map
 _moe_permute_index_map_workspace = None
 _moe_permute_index_map_max_expanded_token_num = 0

@@ -37,24 +45,7 @@ def moe_permute_index_map_forward(
     """Forward pass for MoE permute with index router map."""
     global _moe_permute_index_map_workspace, _moe_permute_index_map_max_expanded_token_num

-    # Empty input check
-    if not inp.numel():
-        return inp, torch.tensor([], device=inp.device)
-
-    # Device check
-    assert inp.is_cuda, "TransformerEngine needs CUDA."
-    assert index.is_cuda, "TransformerEngine needs CUDA."
-    # Shape check
-    assert inp.size(0) == index.size(0), "Permute not possible"
-
-    # Data type check
     dtype = TE_DType[inp.dtype]
-    if index.dtype != torch.int32:
-        warnings.warn(
-            f"The data type of the input `index` of Permute is {index.dtype}! "
-            "The recommended type is torch.int32."
-        )
-        index = index.to(torch.int32)

     topK = index.size(1)

@@ -83,16 +74,13 @@ def _moe_permute_index_map_fake(
     max_token_num: int,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Fake implementation for shape inference."""
-    if not inp.numel():
-        return inp, torch.tensor([], device=inp.device)
-
     num_tokens = inp.shape[0]
-    topK = index.shape[1] if index.numel() > 0 else 1
+    topK = index.shape[1]

-    # Infer output shape (see permutation.cpp line 55)
+    # Infer output shape
     output_tokens = num_out_tokens if num_out_tokens > 0 else num_tokens * topK

-    # row_id_map is 1D with size = num_tokens * topK (see permutation.cpp line 59-60)
+    # row_id_map is 1D with size = num_tokens * topK
     fake_output = torch.empty(
         (output_tokens, inp.shape[1]), dtype=inp.dtype, device=inp.device
     )
@@ -111,9 +99,6 @@ def moe_permute_index_map_backward(
     topK: int,
 ) -> torch.Tensor:
     """Backward pass for MoE permute with index router map."""
-    if not grad_permuted_act.is_contiguous():
-        grad_permuted_act = grad_permuted_act.contiguous()
-
     dtype = TE_DType[grad_permuted_act.dtype]
     act_grad = tex.moe_permute_bwd(
         grad_permuted_act, dtype, row_id_map, torch.empty(0), num_tokens, topK
@@ -141,18 +126,17 @@ def _moe_permute_index_map_setup_context(ctx, inputs, output):
     inp, index, num_out_tokens, max_token_num = inputs
     permuted_act, row_id_map = output
     ctx.save_for_backward(row_id_map)
-    ctx.num_tokens = index.size(0) if index.numel() > 0 else 0
-    ctx.topK = index.size(1) if index.numel() > 0 else 1
+    ctx.num_tokens = index.size(0)
+    ctx.topK = index.size(1)


 def _moe_permute_index_map_backward_wrapper(ctx, grad_permuted_act, grad_row_id_map):
     """Backward pass wrapper that calls the custom backward op."""
-    # Empty input check
-    if not grad_permuted_act.numel():
-        return grad_permuted_act, None, None, None
+    if not grad_permuted_act.is_contiguous():
+        grad_permuted_act = grad_permuted_act.contiguous()

     (row_id_map,) = ctx.saved_tensors
-    act_grad = moe_permute_index_map_backward(
+    act_grad = torch.ops.te_moe.permute_index_map_bwd(
         grad_permuted_act, row_id_map, ctx.num_tokens, ctx.topK
     )

@@ -165,7 +149,7 @@ def _moe_permute_index_map_backward_wrapper(ctx, grad_permuted_act, grad_row_id_
 )


-# ---------------------------------- Forward custom op ----------------------------------
+# ===================== _moe_unpermute_index_map custom ops =====================

 @torch.library.custom_op("te_moe::unpermute_index_map_fwd", mutates_args=[])
 def moe_unpermute_index_map_forward(
@@ -189,14 +173,12 @@ def _moe_unpermute_index_map_forward_fake(
     topK: int,
 ) -> torch.Tensor:
     """Fake implementation for shape inference."""
-    # Output shape: (num_tokens, hidden_size) — see permutation.cpp line 95-97
+    # Output shape: (num_tokens, hidden_size)
     return torch.empty(
         (num_tokens, inp.shape[1]), dtype=inp.dtype, device=inp.device
     )


-# ----------------------------------
Backward custom op ---------------------------------- - @torch.library.custom_op("te_moe::unpermute_index_map_bwd", mutates_args=[]) def moe_unpermute_index_map_backward( unpermuted_act_grad: torch.Tensor, @@ -220,8 +202,8 @@ def _moe_unpermute_index_map_backward_fake( probs: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: """Fake implementation for shape inference of backward.""" - # act_grad shape: (fwd_input.size(0), hidden_size) — see permutation.cpp line 127-129 - # prob_grad shape: (num_tokens, topK) — see permutation.cpp line 130-131 + # act_grad shape: (fwd_input.size(0), hidden_size) + # prob_grad shape: (num_tokens, topK) topK = probs.size(1) if probs.numel() > 0 else 1 num_tokens = probs.size(0) if probs.numel() > 0 else row_id_map.size(0) act_grad = torch.empty( @@ -235,26 +217,22 @@ def _moe_unpermute_index_map_backward_fake( return act_grad, prob_grad -# ---------------------------------- Autograd glue ---------------------------------- def _moe_unpermute_index_map_setup_context(ctx, inputs, output): """Save context for backward pass.""" inp, row_id_map, probs, num_tokens, topK = inputs ctx.save_for_backward(inp, row_id_map, probs) - ctx.needs_probs_grad = probs.requires_grad if probs.numel() > 0 else False + ctx.needs_probs_grad = probs.requires_grad def _moe_unpermute_index_map_backward_wrapper(ctx, unpermuted_act_grad): """Backward pass wrapper that calls the custom backward op.""" - if not unpermuted_act_grad.numel(): - return unpermuted_act_grad, None, None, None, None - if not unpermuted_act_grad.is_contiguous(): unpermuted_act_grad = unpermuted_act_grad.contiguous() inp, row_id_map, probs = ctx.saved_tensors - act_grad, prob_grad = moe_unpermute_index_map_backward( + act_grad, prob_grad = torch.ops.te_moe.unpermute_index_map_bwd( unpermuted_act_grad, inp, row_id_map, probs ) @@ -281,23 +259,6 @@ def moe_permute_mask_map_forward( pad_offsets: Optional[torch.Tensor], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Forward pass for MoE permute with mask router map.""" - # Empty input check - if not inp.numel(): - return ( - inp, - torch.tensor([], device=inp.device), - torch.tensor([], device=inp.device), - ) - - assert inp.is_cuda, "TransformerEngine needs CUDA." - assert routing_map.is_cuda, "TransformerEngine needs CUDA." - if probs is not None: - assert probs.is_cuda, "TransformerEngine needs CUDA." - if pad_offsets is not None: - assert pad_offsets.is_cuda, "TransformerEngine needs CUDA." - assert inp.size(0) == routing_map.size(0), "Permute not possible" - assert num_out_tokens is not None, "num_out_tokens must be provided to the fused permute function." 
- num_tokens, hidden_size = inp.size() num_experts = routing_map.size(1) @@ -382,7 +343,7 @@ def _moe_permute_mask_map_forward_fake( num_tokens = inp.shape[0] hidden_size = inp.shape[1] num_experts = routing_map.shape[1] - # row_id_map: (num_tokens, num_experts * 2 + 1) — see triton make_row_id_map + # row_id_map: (num_tokens, num_experts * 2 + 1) fake_output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device=inp.device) fake_row_id_map = torch.empty( (num_tokens, num_experts * 2 + 1), dtype=torch.int32, device=inp.device @@ -442,25 +403,15 @@ def _moe_permute_mask_map_setup_context(ctx, inputs, output): """Save context for backward pass.""" inp, routing_map, num_out_tokens, probs, pad_offsets = inputs output_tensor, row_id_map, permuted_probs = output - ctx.empty_input = not inp.numel() - if ctx.empty_input and probs is not None: - ctx.save_for_backward(row_id_map, pad_offsets, probs) - else: - ctx.save_for_backward(row_id_map, pad_offsets) - ctx.num_experts = routing_map.size(1) if routing_map.numel() > 0 else 0 + ctx.save_for_backward(row_id_map, pad_offsets) + ctx.num_experts = routing_map.size(1) ctx.num_tokens = inp.size(0) - ctx.hidden_size = inp.size(1) if inp.numel() > 0 else 0 + ctx.hidden_size = inp.size(1) ctx.needs_probs_grad = probs is not None and probs.requires_grad def _moe_permute_mask_map_backward_wrapper(ctx, grad_output, grad_row_id_map, grad_permuted_probs): """Backward wrapper calling the custom backward op.""" - if ctx.empty_input: - if ctx.needs_probs_grad: - _, _, probs = ctx.saved_tensors - return grad_output, None, None, probs, None - return grad_output, None, None, None, None - assert not isinstance( grad_output, QuantizedTensor ), "The backward of moe_permute does not support FP8." @@ -470,7 +421,7 @@ def _moe_permute_mask_map_backward_wrapper(ctx, grad_output, grad_row_id_map, gr # Pass permuted_probs_grad only if it has content probs_grad_input = grad_permuted_probs if grad_permuted_probs.numel() > 0 else None - act_grad, probs_grad = moe_permute_mask_map_backward( + act_grad, probs_grad = torch.ops.te_moe.permute_mask_map_bwd( grad_output, probs_grad_input, row_id_map, pad_offsets, ctx.num_tokens, ctx.num_experts, ctx.hidden_size, ) @@ -500,10 +451,6 @@ def moe_unpermute_mask_map_forward( pad_offsets: Optional[torch.Tensor], ) -> torch.Tensor: """Forward pass for MoE unpermute with mask router map.""" - # Empty input check - if not inp.numel(): - return inp - assert not isinstance( inp, QuantizedTensor ), "The forward of moe_unpermute does not support FP8." 
@@ -670,7 +617,6 @@ def _moe_unpermute_mask_map_setup_context(ctx, inputs, output): ctx.num_permuted_tokens = inp.size(0) ctx.hidden_size = hidden_size ctx.with_probs = merging_probs is not None - ctx.empty_input = not inp.numel() if ctx.with_probs: ctx.save_for_backward(inp, row_id_map, merging_probs, pad_offsets) ctx.needs_probs_grad = merging_probs.requires_grad @@ -681,13 +627,6 @@ def _moe_unpermute_mask_map_setup_context(ctx, inputs, output): def _moe_unpermute_mask_map_backward_wrapper(ctx, unpermuted_act_grad): """Backward wrapper calling the appropriate custom backward op.""" - if ctx.empty_input: - # Return merging_probs as its own grad for empty input (matches original behavior) - if ctx.with_probs: - _, _, merging_probs, _ = ctx.saved_tensors - return unpermuted_act_grad, None, merging_probs, None, None, None, None - return unpermuted_act_grad, None, None, None, None, None, None - act_grad = None probs_grad = None @@ -696,13 +635,13 @@ def _moe_unpermute_mask_map_backward_wrapper(ctx, unpermuted_act_grad): assert not isinstance( unpermuted_act_grad, QuantizedTensor ), "The backward of moe_unpermute with merging probs does not support FP8." - act_grad, probs_grad = moe_unpermute_mask_map_backward_with_probs( + act_grad, probs_grad = torch.ops.te_moe.unpermute_mask_map_bwd_with_probs( unpermuted_act_grad, row_id_map, fwd_input, merging_probs, pad_offsets, ctx.num_tokens, ctx.num_experts, ctx.num_permuted_tokens, ctx.hidden_size, ) else: row_id_map, pad_offsets = ctx.saved_tensors - act_grad = moe_unpermute_mask_map_backward_no_probs( + act_grad = torch.ops.te_moe.unpermute_mask_map_bwd_no_probs( unpermuted_act_grad, row_id_map, pad_offsets, ctx.num_tokens, ctx.num_experts, ctx.num_permuted_tokens, ctx.hidden_size, ) @@ -718,6 +657,16 @@ def _moe_unpermute_mask_map_backward_wrapper(ctx, unpermuted_act_grad): setup_context=_moe_unpermute_mask_map_setup_context, ) +# Register all te_moe custom ops as passthrough in QuantizedTensor.__torch_dispatch__ +# so that FP8 tensors are not unwrapped before entering these ops. +_quantized_tensor_passthrough_ops.update({ + torch.ops.te_moe.permute_mask_map_fwd.default, + torch.ops.te_moe.permute_mask_map_bwd.default, + torch.ops.te_moe.unpermute_mask_map_fwd.default, + torch.ops.te_moe.unpermute_mask_map_bwd_with_probs.default, + torch.ops.te_moe.unpermute_mask_map_bwd_no_probs.default, +}) + def moe_permute( inp: torch.Tensor, @@ -753,10 +702,21 @@ def moe_permute( Options are: 'mask', 'index'. Refer to `routing_map` for more details. """ + if not inp.numel(): + return inp, torch.tensor([], device=inp.device) + assert inp.is_cuda, "TransformerEngine needs CUDA." + assert routing_map.is_cuda, "TransformerEngine needs CUDA." + assert inp.size(0) == routing_map.size(0), "Permute not possible" + if routing_map.dtype != torch.int32: + warnings.warn( + f"The data type of the input `routing_map` of Permute is {routing_map.dtype}! " + "The recommended type is torch.int32." 
+ ) + routing_map = routing_map.to(torch.int32) if map_type == "index": - return moe_permute_index_map_forward(inp, routing_map, num_out_tokens, max_token_num) + return torch.ops.te_moe.permute_index_map(inp, routing_map, num_out_tokens, max_token_num) if map_type == "mask": - output, row_id_map, _ = moe_permute_mask_map_forward( + output, row_id_map, _ = torch.ops.te_moe.permute_mask_map_fwd( inp, routing_map, num_out_tokens, None, None ) return output, row_id_map @@ -790,7 +750,15 @@ def moe_permute_with_probs( The effective output token count, representing the number of tokens not dropped. By default, set to '-1', meaning no tokens are dropped. """ - output, row_id_map, permuted_probs = moe_permute_mask_map_forward( + if not inp.numel(): + # Keep probs in autograd graph so that probs.grad is an empty tensor + # instead of None after backward (backward compatibility). + return ( + inp + probs.sum() * 0, + probs.sum(dim=1), + torch.tensor([], device=inp.device), + ) + output, row_id_map, permuted_probs = torch.ops.te_moe.permute_mask_map_fwd( inp, routing_map, num_out_tokens, probs, None ) return output, permuted_probs, row_id_map @@ -846,7 +814,7 @@ def moe_permute_and_pad_with_probs( [torch.zeros(1, dtype=cum_pad.dtype, device=inp.device), cum_pad[:-1]] ) - output, row_id_map, permuted_probs = moe_permute_mask_map_forward( + output, row_id_map, permuted_probs = torch.ops.te_moe.permute_mask_map_fwd( inp, routing_map, target_tokens_per_expert.sum().item(), probs, pad_offsets ) return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert @@ -926,130 +894,184 @@ def moe_unpermute( ) row_id_map = row_id_map.to(torch.int32) - return moe_unpermute_index_map_forward(inp, row_id_map, merging_probs, num_tokens, topK) + return torch.ops.te_moe.unpermute_index_map_fwd(inp, row_id_map, merging_probs, num_tokens, topK) if map_type == "mask": + if not inp.numel(): + # Keep merging_probs in autograd graph so that probs.grad is an empty + # tensor instead of None after backward (backward compatibility). + if merging_probs is not None: + return inp + merging_probs.sum() * 0 + return inp + if restore_shape is None: restore_shape = inp.shape num_tokens, hidden_size = restore_shape - num_experts = (row_id_map.size(1) - 1) // 2 if row_id_map.numel() > 0 else 0 + num_experts = (row_id_map.size(1) - 1) // 2 - if not inp.numel(): - # Pass through custom op even for empty input so probs stays in the graph - pass - else: - if merging_probs is not None: - assert merging_probs.is_cuda, "TransformerEngine needs CUDA." - assert inp.is_cuda, "TransformerEngine needs CUDA." - assert row_id_map.is_cuda, "TransformerEngine needs CUDA." - if pad_offsets is not None: - assert pad_offsets.is_cuda, "TransformerEngine needs CUDA." + if merging_probs is not None: + assert merging_probs.is_cuda, "TransformerEngine needs CUDA." + assert inp.is_cuda, "TransformerEngine needs CUDA." + assert row_id_map.is_cuda, "TransformerEngine needs CUDA." + if pad_offsets is not None: + assert pad_offsets.is_cuda, "TransformerEngine needs CUDA." 
- return moe_unpermute_mask_map_forward( + return torch.ops.te_moe.unpermute_mask_map_fwd( inp, row_id_map, merging_probs, num_tokens, num_experts, hidden_size, pad_offsets, ) raise ValueError("map_type should be one of 'mask' or 'index'") -class _moe_chunk_sort(torch.autograd.Function): - """functional MoE chunk permute""" +# ===================== _moe_chunk_sort custom ops ===================== - @staticmethod - def forward( - ctx, - inp: torch.Tensor, - split_sizes: torch.Tensor, - sorted_idxs: torch.Tensor, - probs: torch.Tensor, - ) -> Tuple[torch.Tensor, torch.Tensor]: - # pylint: disable=missing-function-docstring - if not inp.numel(): - return inp, probs +@torch.library.custom_op("te_moe::chunk_sort_fwd", mutates_args=[]) +def moe_chunk_sort_forward( + inp: torch.Tensor, + split_sizes: torch.Tensor, + sorted_idxs: torch.Tensor, + probs: Optional[torch.Tensor], +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward pass for MoE chunk sort. Returns (output, permuted_probs, row_id_map).""" + num_tokens, hidden_size = inp.shape + num_splits = split_sizes.size(0) - assert inp.is_cuda, "TransformerEngine needs CUDA." - assert split_sizes.is_cuda, "TransformerEngine needs CUDA." - assert sorted_idxs.is_cuda, "TransformerEngine needs CUDA." - if probs is not None: - assert probs.is_cuda, "TransformerEngine needs CUDA." - - num_tokens, hidden_size = inp.shape - num_splits = split_sizes.size(0) - assert num_splits == sorted_idxs.size(0) - - fp8 = isinstance(inp, Float8Tensor) - if fp8: - fp8_dtype = inp._fp8_dtype - fp8_scale_inv = inp._scale_inv - fake_dtype = inp.dtype - inp = inp._data + fp8 = isinstance(inp, Float8Tensor) + if fp8: + fp8_dtype = inp._fp8_dtype + fp8_scale_inv = inp._scale_inv + fake_dtype = inp.dtype + inp = inp._data - row_id_map = triton_permutation.make_chunk_sort_map( - split_sizes, - sorted_idxs, - num_tokens, - num_splits, + row_id_map = triton_permutation.make_chunk_sort_map( + split_sizes, sorted_idxs, num_tokens, num_splits, + ) + output, permuted_probs = triton_permutation.sort_chunks_by_map( + inp, row_id_map, probs, num_tokens, hidden_size, is_forward=True, + ) + if fp8: + output = Float8Tensor( + data=output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv, + shape=output.shape, dtype=fake_dtype, ) - output, permuted_probs = triton_permutation.sort_chunks_by_map( - inp, - row_id_map, - probs, - num_tokens, - hidden_size, - is_forward=True, + + if permuted_probs is None: + permuted_probs = torch.empty(0, device=output.device) + + return output, permuted_probs, row_id_map + + +@moe_chunk_sort_forward.register_fake +def _moe_chunk_sort_forward_fake( + inp: torch.Tensor, + split_sizes: torch.Tensor, + sorted_idxs: torch.Tensor, + probs: Optional[torch.Tensor], +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Fake for shape inference.""" + num_tokens = inp.shape[0] + hidden_size = inp.shape[1] + fake_output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device=inp.device) + if probs is not None: + fake_probs = torch.empty((num_tokens,), dtype=probs.dtype, device=inp.device) + else: + fake_probs = torch.empty(0, device=inp.device) + # row_id_map: 1D, size num_tokens + fake_row_id_map = torch.empty((num_tokens,), dtype=torch.int32, device=inp.device) + return fake_output, fake_probs, fake_row_id_map + + +@torch.library.custom_op("te_moe::chunk_sort_bwd", mutates_args=[]) +def moe_chunk_sort_backward( + permuted_act_grad: torch.Tensor, + permuted_probs_grad: Optional[torch.Tensor], + row_id_map: torch.Tensor, + num_tokens: int, + 
hidden_size: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Backward pass for MoE chunk sort.""" + fp8 = isinstance(permuted_act_grad, Float8Tensor) + if fp8: + fp8_dtype = permuted_act_grad._fp8_dtype + fp8_scale_inv = permuted_act_grad._scale_inv + fake_dtype = permuted_act_grad.dtype + permuted_act_grad = permuted_act_grad._data + + act_grad, probs_grad = triton_permutation.sort_chunks_by_map( + permuted_act_grad, row_id_map, permuted_probs_grad, + num_tokens, hidden_size, is_forward=False, + ) + + if fp8: + act_grad = Float8Tensor( + data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv, + shape=act_grad.shape, dtype=fake_dtype, ) - if fp8: - output = Float8Tensor( - data=output, - fp8_dtype=fp8_dtype, - fp8_scale_inv=fp8_scale_inv, - shape=output.shape, - dtype=fake_dtype, - ) - ctx.save_for_backward(row_id_map) - ctx.num_tokens = num_tokens - ctx.hidden_size = hidden_size - return output, permuted_probs - - @staticmethod - def backward( - ctx, - permuted_act_grad: torch.Tensor, - permuted_probs_grad: torch.Tensor, - ) -> Tuple[torch.Tensor, ...]: - # pylint: disable=missing-function-docstring - if not permuted_act_grad.numel(): - return permuted_act_grad, None, None, permuted_probs_grad - - act_grad = None + if probs_grad is None: + probs_grad = torch.empty(0, device=act_grad.device) + + return act_grad, probs_grad + + +@moe_chunk_sort_backward.register_fake +def _moe_chunk_sort_backward_fake( + permuted_act_grad: torch.Tensor, + permuted_probs_grad: Optional[torch.Tensor], + row_id_map: torch.Tensor, + num_tokens: int, + hidden_size: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Fake for backward shape inference.""" + fake_act_grad = torch.empty( + (num_tokens, hidden_size), dtype=permuted_act_grad.dtype, device=permuted_act_grad.device, + ) + if permuted_probs_grad is not None: + fake_probs_grad = torch.empty( + (num_tokens,), dtype=permuted_probs_grad.dtype, device=permuted_act_grad.device, + ) + else: + fake_probs_grad = torch.empty(0, device=permuted_act_grad.device) + return fake_act_grad, fake_probs_grad + + +def _moe_chunk_sort_setup_context(ctx, inputs, output): + """Save context for backward pass.""" + inp, split_sizes, sorted_idxs, probs = inputs + output_tensor, permuted_probs, row_id_map = output + + ctx.save_for_backward(row_id_map) + ctx.num_tokens = inp.size(0) + ctx.hidden_size = inp.size(1) + ctx.needs_probs_grad = probs is not None and probs.requires_grad + + +def _moe_chunk_sort_backward_wrapper(ctx, permuted_act_grad, permuted_probs_grad, _row_id_map_grad): + """Backward wrapper calling the custom backward op.""" + (row_id_map,) = ctx.saved_tensors + + probs_grad_input = permuted_probs_grad if permuted_probs_grad.numel() > 0 else None + + act_grad, probs_grad = torch.ops.te_moe.chunk_sort_bwd( + permuted_act_grad, probs_grad_input, row_id_map, + ctx.num_tokens, ctx.hidden_size, + ) + + if not ctx.needs_probs_grad or probs_grad.numel() == 0: probs_grad = None - if ctx.needs_input_grad[0]: - (row_id_map,) = ctx.saved_tensors - fp8 = isinstance(permuted_act_grad, Float8Tensor) - if fp8: - fp8_dtype = permuted_act_grad._fp8_dtype - fp8_scale_inv = permuted_act_grad._scale_inv - fake_dtype = permuted_act_grad.dtype - permuted_act_grad = permuted_act_grad._data - act_grad, probs_grad = triton_permutation.sort_chunks_by_map( - permuted_act_grad, - row_id_map, - permuted_probs_grad, - ctx.num_tokens, - ctx.hidden_size, - is_forward=False, - ) - if fp8: - act_grad = Float8Tensor( - data=act_grad, - fp8_dtype=fp8_dtype, - fp8_scale_inv=fp8_scale_inv, - 
shape=act_grad.shape, - dtype=fake_dtype, - ) - if not ctx.needs_input_grad[3]: - probs_grad = None - return act_grad, None, None, probs_grad + + return act_grad, None, None, probs_grad + + +moe_chunk_sort_forward.register_autograd( + _moe_chunk_sort_backward_wrapper, + setup_context=_moe_chunk_sort_setup_context, +) + +# Register chunk sort ops as passthrough in QuantizedTensor.__torch_dispatch__ +_quantized_tensor_passthrough_ops.update({ + torch.ops.te_moe.chunk_sort_fwd.default, + torch.ops.te_moe.chunk_sort_bwd.default, +}) def moe_sort_chunks_by_index( @@ -1071,7 +1093,9 @@ def moe_sort_chunks_by_index( sorted_indices : torch.Tensor Chunk indices used to permute the chunks. """ - output, _ = _moe_chunk_sort.apply(inp, split_sizes, sorted_index, None) + if not inp.numel(): + return inp + output, _, _ = torch.ops.te_moe.chunk_sort_fwd(inp, split_sizes, sorted_index, None) return output @@ -1099,5 +1123,7 @@ def moe_sort_chunks_by_index_with_probs( sorted_indices : torch.Tensor Chunk indices used to permute the chunks. """ - output, permuted_probs = _moe_chunk_sort.apply(inp, split_sizes, sorted_index, probs) + if not inp.numel(): + return inp, probs + output, permuted_probs, _ = torch.ops.te_moe.chunk_sort_fwd(inp, split_sizes, sorted_index, probs) return output, permuted_probs diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 678b884812..1bac4b2b53 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -21,6 +21,12 @@ ) +# Custom ops that should pass through __torch_dispatch__ without unwrapping +# QuantizedTensor subclasses (e.g. Float8Tensor). Register ops here that +# handle quantized tensors internally. +_quantized_tensor_passthrough_ops: set = set() + + class QuantizedTensorStorage: r"""Base class for all TensorStorage classes. @@ -516,8 +522,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): return func(t) return False # Or error out? 
- # Pass through te_moe custom ops without unwrapping - if hasattr(func, "namespace") and func.namespace == "te_moe": + # Pass through registered custom ops without unwrapping + if func in _quantized_tensor_passthrough_ops: if kwargs is None: kwargs = {} return super().__torch_dispatch__(func, types, args, kwargs) From dcdb413ed9fc73c5614dbf9653ba02a8e79e8061 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 18 Feb 2026 17:32:14 +0000 Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_permutation.py | 34 +-- transformer_engine/pytorch/permutation.py | 268 ++++++++++++++++------ 2 files changed, 210 insertions(+), 92 deletions(-) diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index bc5957a18b..77796f203b 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -223,6 +223,7 @@ def _maybe_compile(fn, use_torch_compile): if use_torch_compile: torch._dynamo.reset() import torch._functorch.config as functorch_config + functorch_config.donated_buffer = False return torch.compile(fn, fullgraph=True) return fn @@ -310,12 +311,12 @@ def _test_permutation_index_map( te_permute_bwd_input = pytorch_permute_bwd_input.detach() _permute = _maybe_compile( - lambda inp, idx, num_out, max_token: te_permute(inp, idx, num_out, max_token, map_type="index"), + lambda inp, idx, num_out, max_token: te_permute( + inp, idx, num_out, max_token, map_type="index" + ), use_torch_compile, ) - te_permute_output, row_id_map = _permute( - te_permute_fwd_input, indices, num_out_tokens, -1 - ) + te_permute_output, row_id_map = _permute(te_permute_fwd_input, indices, num_out_tokens, -1) te_permute_output.backward(te_permute_bwd_input, retain_graph=True) te_probs = None @@ -330,9 +331,7 @@ def _test_permutation_index_map( lambda inp, row_map, probs_val: te_unpermute(inp, row_map, probs_val, map_type="index"), use_torch_compile, ) - te_unpermute_output = _unpermute( - te_unpermute_fwd_input, row_id_map, te_probs - ) + te_unpermute_output = _unpermute(te_unpermute_fwd_input, row_id_map, te_probs) te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True) ################################################################################################################################### @@ -538,9 +537,7 @@ def _test_permutation_mask_map( lambda inp, rmap, n_out: te_permute(inp, rmap, num_out_tokens=n_out, map_type="mask"), use_torch_compile, ) - te_permute_output, row_id_map = _permute( - te_permute_fwd_input, routing_map, num_out_tokens - ) + te_permute_output, row_id_map = _permute(te_permute_fwd_input, routing_map, num_out_tokens) te_permute_output.backward(te_permute_bwd_input, retain_graph=True) te_probs = None @@ -555,9 +552,7 @@ def _test_permutation_mask_map( lambda inp, row_map, p, rs: te_unpermute(inp, row_map, p, rs, map_type="mask"), use_torch_compile, ) - te_unpermute_output = _unpermute( - te_unpermute_fwd_input, row_id_map, te_probs, restore_shape - ) + te_unpermute_output = _unpermute(te_unpermute_fwd_input, row_id_map, te_probs, restore_shape) te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True) ################################################################################################################################### @@ -1547,7 +1542,9 @@ def _test_permutation_mask_map_alongside_probs( te_probs.requires_grad_(True) def _alongside_probs_fn(fwd_inp, 
t_probs, rmap, ss1, si1, ss2, si2): - out, pprobs, rid = te_permute_with_probs(fwd_inp, t_probs, rmap, num_out_tokens=num_out_tokens) + out, pprobs, rid = te_permute_with_probs( + fwd_inp, t_probs, rmap, num_out_tokens=num_out_tokens + ) out, pprobs = te_sort_chunks_by_index_with_probs(out, pprobs, ss1, si1) out_dtype = out.dtype out = out * pprobs.unsqueeze(-1) @@ -1558,8 +1555,13 @@ def _alongside_probs_fn(fwd_inp, t_probs, rmap, ss1, si1, ss2, si2): _fn = _maybe_compile(_alongside_probs_fn, use_torch_compile) te_unpermute_output = _fn( - te_permute_fwd_input, te_probs, routing_map, - split_sizes_cuda, sorted_idxs_cuda, split_sizes_2_cuda, sorted_idxs_2_cuda, + te_permute_fwd_input, + te_probs, + routing_map, + split_sizes_cuda, + sorted_idxs_cuda, + split_sizes_2_cuda, + sorted_idxs_2_cuda, ) te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True) diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 7d2cb8f6f9..5b6498e70e 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -81,12 +81,8 @@ def _moe_permute_index_map_fake( output_tokens = num_out_tokens if num_out_tokens > 0 else num_tokens * topK # row_id_map is 1D with size = num_tokens * topK - fake_output = torch.empty( - (output_tokens, inp.shape[1]), dtype=inp.dtype, device=inp.device - ) - fake_row_id_map = torch.empty( - (num_tokens * topK,), dtype=torch.int32, device=inp.device - ) + fake_output = torch.empty((output_tokens, inp.shape[1]), dtype=inp.dtype, device=inp.device) + fake_row_id_map = torch.empty((num_tokens * topK,), dtype=torch.int32, device=inp.device) return fake_output, fake_row_id_map @@ -151,6 +147,7 @@ def _moe_permute_index_map_backward_wrapper(ctx, grad_permuted_act, grad_row_id_ # ===================== _moe_unpermute_index_map custom ops ===================== + @torch.library.custom_op("te_moe::unpermute_index_map_fwd", mutates_args=[]) def moe_unpermute_index_map_forward( inp: torch.Tensor, @@ -174,9 +171,7 @@ def _moe_unpermute_index_map_forward_fake( ) -> torch.Tensor: """Fake implementation for shape inference.""" # Output shape: (num_tokens, hidden_size) - return torch.empty( - (num_tokens, inp.shape[1]), dtype=inp.dtype, device=inp.device - ) + return torch.empty((num_tokens, inp.shape[1]), dtype=inp.dtype, device=inp.device) @torch.library.custom_op("te_moe::unpermute_index_map_bwd", mutates_args=[]) @@ -217,7 +212,6 @@ def _moe_unpermute_index_map_backward_fake( return act_grad, prob_grad - def _moe_unpermute_index_map_setup_context(ctx, inputs, output): """Save context for backward pass.""" inp, row_id_map, probs, num_tokens, topK = inputs @@ -250,6 +244,7 @@ def _moe_unpermute_index_map_backward_wrapper(ctx, unpermuted_act_grad): # ===================== _moe_permute_mask_map custom ops ===================== + @torch.library.custom_op("te_moe::permute_mask_map_fwd", mutates_args=[]) def moe_permute_mask_map_forward( inp: torch.Tensor, @@ -296,30 +291,51 @@ def moe_permute_mask_map_forward( scale_hidden_dim = None output, permuted_scale, permuted_probs = triton_permutation.permute_with_mask_map( - inp, row_id_map, probs, fp8_scale, pad_offsets, - num_tokens, num_experts, num_out_tokens, hidden_size, scale_hidden_dim, + inp, + row_id_map, + probs, + fp8_scale, + pad_offsets, + num_tokens, + num_experts, + num_out_tokens, + hidden_size, + scale_hidden_dim, ) if fp8: if per_tensor_recipe: output = Float8Tensor( - data=output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv, - 
shape=output.shape, dtype=fake_dtype, + data=output, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + shape=output.shape, + dtype=fake_dtype, ) elif blockwise_recipe: output = Float8BlockwiseQTensor( - shape=output.shape, dtype=fake_dtype, rowwise_data=output, + shape=output.shape, + dtype=fake_dtype, + rowwise_data=output, rowwise_scale_inv=permuted_scale.T.contiguous(), - columnwise_data=None, columnwise_scale_inv=None, - fp8_dtype=fp8_dtype, quantizer=None, is_2D_scaled=False, + columnwise_data=None, + columnwise_scale_inv=None, + fp8_dtype=fp8_dtype, + quantizer=None, + is_2D_scaled=False, requires_grad=output.requires_grad, ) elif mxfp8_recipe: output = MXFP8Tensor( - shape=output.shape, dtype=fake_dtype, fp8_dtype=fp8_dtype, - rowwise_data=output, rowwise_scale_inv=permuted_scale.contiguous(), - columnwise_data=None, columnwise_scale_inv=None, - quantizer=None, requires_grad=output.requires_grad, + shape=output.shape, + dtype=fake_dtype, + fp8_dtype=fp8_dtype, + rowwise_data=output, + rowwise_scale_inv=permuted_scale.contiguous(), + columnwise_data=None, + columnwise_scale_inv=None, + quantizer=None, + requires_grad=output.requires_grad, with_gemm_swizzled_scales=False, ) @@ -327,7 +343,6 @@ def moe_permute_mask_map_forward( if permuted_probs is None: permuted_probs = torch.empty(0, device=inp.device) - return output, row_id_map, permuted_probs @@ -367,8 +382,14 @@ def moe_permute_mask_map_backward( ) -> Tuple[torch.Tensor, torch.Tensor]: """Backward pass for MoE permute with mask router map.""" act_grad, probs_grad = triton_permutation.unpermute_with_mask_map( - permuted_act_grad, row_id_map, None, permuted_probs_grad, pad_offsets, - num_tokens, num_experts, hidden_size, + permuted_act_grad, + row_id_map, + None, + permuted_probs_grad, + pad_offsets, + num_tokens, + num_experts, + hidden_size, ) if probs_grad is None: probs_grad = torch.empty(0, device=permuted_act_grad.device) @@ -391,7 +412,8 @@ def _moe_permute_mask_map_backward_fake( ) if permuted_probs_grad is not None: probs_grad = torch.empty( - (num_tokens, num_experts), dtype=permuted_probs_grad.dtype, + (num_tokens, num_experts), + dtype=permuted_probs_grad.dtype, device=permuted_act_grad.device, ) else: @@ -422,8 +444,13 @@ def _moe_permute_mask_map_backward_wrapper(ctx, grad_output, grad_row_id_map, gr probs_grad_input = grad_permuted_probs if grad_permuted_probs.numel() > 0 else None act_grad, probs_grad = torch.ops.te_moe.permute_mask_map_bwd( - grad_output, probs_grad_input, row_id_map, pad_offsets, - ctx.num_tokens, ctx.num_experts, ctx.hidden_size, + grad_output, + probs_grad_input, + row_id_map, + pad_offsets, + ctx.num_tokens, + ctx.num_experts, + ctx.hidden_size, ) if not ctx.needs_probs_grad or probs_grad.numel() == 0: @@ -440,6 +467,7 @@ def _moe_permute_mask_map_backward_wrapper(ctx, grad_output, grad_row_id_map, gr # ===================== _moe_unpermute_mask_map custom ops ===================== + @torch.library.custom_op("te_moe::unpermute_mask_map_fwd", mutates_args=[]) def moe_unpermute_mask_map_forward( inp: torch.Tensor, @@ -455,8 +483,14 @@ def moe_unpermute_mask_map_forward( inp, QuantizedTensor ), "The forward of moe_unpermute does not support FP8." 
unpermuted_output, _ = triton_permutation.unpermute_with_mask_map( - inp, row_id_map, merging_probs, None, pad_offsets, - num_tokens, num_experts, hidden_size, + inp, + row_id_map, + merging_probs, + None, + pad_offsets, + num_tokens, + num_experts, + hidden_size, ) return unpermuted_output @@ -489,8 +523,15 @@ def moe_unpermute_mask_map_backward_with_probs( ) -> Tuple[torch.Tensor, torch.Tensor]: """Backward pass for MoE unpermute with merging probs.""" act_grad, probs_grad = triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs( - unpermuted_act_grad, row_id_map, fwd_input, merging_probs, pad_offsets, - num_tokens, num_experts, num_permuted_tokens, hidden_size, + unpermuted_act_grad, + row_id_map, + fwd_input, + merging_probs, + pad_offsets, + num_tokens, + num_experts, + num_permuted_tokens, + hidden_size, ) return act_grad, probs_grad @@ -510,11 +551,13 @@ def _moe_unpermute_mask_map_bwd_with_probs_fake( """Fake for backward shape inference with merging probs.""" act_grad = torch.empty( (num_permuted_tokens, hidden_size), - dtype=unpermuted_act_grad.dtype, device=unpermuted_act_grad.device, + dtype=unpermuted_act_grad.dtype, + device=unpermuted_act_grad.device, ) probs_grad = torch.empty( (num_tokens, num_experts), - dtype=merging_probs.dtype, device=unpermuted_act_grad.device, + dtype=merging_probs.dtype, + device=unpermuted_act_grad.device, ) return act_grad, probs_grad @@ -562,30 +605,51 @@ def moe_unpermute_mask_map_backward_no_probs( fp8_scale = None act_grad, permuted_scale, _ = triton_permutation.permute_with_mask_map( - unpermuted_act_grad, row_id_map, None, fp8_scale, pad_offsets, - num_tokens, num_experts, num_permuted_tokens, hidden_size, scale_hidden_dim, + unpermuted_act_grad, + row_id_map, + None, + fp8_scale, + pad_offsets, + num_tokens, + num_experts, + num_permuted_tokens, + hidden_size, + scale_hidden_dim, ) if fp8: if per_tensor_recipe: act_grad = Float8Tensor( - data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv, - shape=act_grad.shape, dtype=fake_dtype, + data=act_grad, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + shape=act_grad.shape, + dtype=fake_dtype, ) elif blockwise_recipe: act_grad = Float8BlockwiseQTensor( - shape=act_grad.shape, dtype=fake_dtype, rowwise_data=act_grad, + shape=act_grad.shape, + dtype=fake_dtype, + rowwise_data=act_grad, rowwise_scale_inv=permuted_scale.T.contiguous(), - columnwise_data=None, columnwise_scale_inv=None, - fp8_dtype=fp8_dtype, quantizer=None, is_2D_scaled=False, + columnwise_data=None, + columnwise_scale_inv=None, + fp8_dtype=fp8_dtype, + quantizer=None, + is_2D_scaled=False, requires_grad=act_grad.requires_grad, ) elif mxfp8_recipe: act_grad = MXFP8Tensor( - shape=act_grad.shape, dtype=fake_dtype, fp8_dtype=fp8_dtype, - rowwise_data=act_grad, rowwise_scale_inv=permuted_scale.contiguous(), - columnwise_data=None, columnwise_scale_inv=None, - quantizer=None, requires_grad=act_grad.requires_grad, + shape=act_grad.shape, + dtype=fake_dtype, + fp8_dtype=fp8_dtype, + rowwise_data=act_grad, + rowwise_scale_inv=permuted_scale.contiguous(), + columnwise_data=None, + columnwise_scale_inv=None, + quantizer=None, + requires_grad=act_grad.requires_grad, with_gemm_swizzled_scales=False, ) @@ -605,7 +669,8 @@ def _moe_unpermute_mask_map_bwd_no_probs_fake( """Fake for backward shape inference without probs.""" return torch.empty( (num_permuted_tokens, hidden_size), - dtype=unpermuted_act_grad.dtype, device=unpermuted_act_grad.device, + dtype=unpermuted_act_grad.dtype, + device=unpermuted_act_grad.device, 
) @@ -636,14 +701,26 @@ def _moe_unpermute_mask_map_backward_wrapper(ctx, unpermuted_act_grad): unpermuted_act_grad, QuantizedTensor ), "The backward of moe_unpermute with merging probs does not support FP8." act_grad, probs_grad = torch.ops.te_moe.unpermute_mask_map_bwd_with_probs( - unpermuted_act_grad, row_id_map, fwd_input, merging_probs, pad_offsets, - ctx.num_tokens, ctx.num_experts, ctx.num_permuted_tokens, ctx.hidden_size, + unpermuted_act_grad, + row_id_map, + fwd_input, + merging_probs, + pad_offsets, + ctx.num_tokens, + ctx.num_experts, + ctx.num_permuted_tokens, + ctx.hidden_size, ) else: row_id_map, pad_offsets = ctx.saved_tensors act_grad = torch.ops.te_moe.unpermute_mask_map_bwd_no_probs( - unpermuted_act_grad, row_id_map, pad_offsets, - ctx.num_tokens, ctx.num_experts, ctx.num_permuted_tokens, ctx.hidden_size, + unpermuted_act_grad, + row_id_map, + pad_offsets, + ctx.num_tokens, + ctx.num_experts, + ctx.num_permuted_tokens, + ctx.hidden_size, ) if not ctx.needs_probs_grad: @@ -659,13 +736,15 @@ def _moe_unpermute_mask_map_backward_wrapper(ctx, unpermuted_act_grad): # Register all te_moe custom ops as passthrough in QuantizedTensor.__torch_dispatch__ # so that FP8 tensors are not unwrapped before entering these ops. -_quantized_tensor_passthrough_ops.update({ - torch.ops.te_moe.permute_mask_map_fwd.default, - torch.ops.te_moe.permute_mask_map_bwd.default, - torch.ops.te_moe.unpermute_mask_map_fwd.default, - torch.ops.te_moe.unpermute_mask_map_bwd_with_probs.default, - torch.ops.te_moe.unpermute_mask_map_bwd_no_probs.default, -}) +_quantized_tensor_passthrough_ops.update( + { + torch.ops.te_moe.permute_mask_map_fwd.default, + torch.ops.te_moe.permute_mask_map_bwd.default, + torch.ops.te_moe.unpermute_mask_map_fwd.default, + torch.ops.te_moe.unpermute_mask_map_bwd_with_probs.default, + torch.ops.te_moe.unpermute_mask_map_bwd_no_probs.default, + } +) def moe_permute( @@ -894,7 +973,9 @@ def moe_unpermute( ) row_id_map = row_id_map.to(torch.int32) - return torch.ops.te_moe.unpermute_index_map_fwd(inp, row_id_map, merging_probs, num_tokens, topK) + return torch.ops.te_moe.unpermute_index_map_fwd( + inp, row_id_map, merging_probs, num_tokens, topK + ) if map_type == "mask": if not inp.numel(): # Keep merging_probs in autograd graph so that probs.grad is an empty @@ -916,14 +997,20 @@ def moe_unpermute( assert pad_offsets.is_cuda, "TransformerEngine needs CUDA." 
return torch.ops.te_moe.unpermute_mask_map_fwd( - inp, row_id_map, merging_probs, - num_tokens, num_experts, hidden_size, pad_offsets, + inp, + row_id_map, + merging_probs, + num_tokens, + num_experts, + hidden_size, + pad_offsets, ) raise ValueError("map_type should be one of 'mask' or 'index'") # ===================== _moe_chunk_sort custom ops ===================== + @torch.library.custom_op("te_moe::chunk_sort_fwd", mutates_args=[]) def moe_chunk_sort_forward( inp: torch.Tensor, @@ -943,15 +1030,26 @@ def moe_chunk_sort_forward( inp = inp._data row_id_map = triton_permutation.make_chunk_sort_map( - split_sizes, sorted_idxs, num_tokens, num_splits, + split_sizes, + sorted_idxs, + num_tokens, + num_splits, ) output, permuted_probs = triton_permutation.sort_chunks_by_map( - inp, row_id_map, probs, num_tokens, hidden_size, is_forward=True, + inp, + row_id_map, + probs, + num_tokens, + hidden_size, + is_forward=True, ) if fp8: output = Float8Tensor( - data=output, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv, - shape=output.shape, dtype=fake_dtype, + data=output, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + shape=output.shape, + dtype=fake_dtype, ) if permuted_probs is None: @@ -997,14 +1095,21 @@ def moe_chunk_sort_backward( permuted_act_grad = permuted_act_grad._data act_grad, probs_grad = triton_permutation.sort_chunks_by_map( - permuted_act_grad, row_id_map, permuted_probs_grad, - num_tokens, hidden_size, is_forward=False, + permuted_act_grad, + row_id_map, + permuted_probs_grad, + num_tokens, + hidden_size, + is_forward=False, ) if fp8: act_grad = Float8Tensor( - data=act_grad, fp8_dtype=fp8_dtype, fp8_scale_inv=fp8_scale_inv, - shape=act_grad.shape, dtype=fake_dtype, + data=act_grad, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + shape=act_grad.shape, + dtype=fake_dtype, ) if probs_grad is None: @@ -1023,11 +1128,15 @@ def _moe_chunk_sort_backward_fake( ) -> Tuple[torch.Tensor, torch.Tensor]: """Fake for backward shape inference.""" fake_act_grad = torch.empty( - (num_tokens, hidden_size), dtype=permuted_act_grad.dtype, device=permuted_act_grad.device, + (num_tokens, hidden_size), + dtype=permuted_act_grad.dtype, + device=permuted_act_grad.device, ) if permuted_probs_grad is not None: fake_probs_grad = torch.empty( - (num_tokens,), dtype=permuted_probs_grad.dtype, device=permuted_act_grad.device, + (num_tokens,), + dtype=permuted_probs_grad.dtype, + device=permuted_act_grad.device, ) else: fake_probs_grad = torch.empty(0, device=permuted_act_grad.device) @@ -1052,8 +1161,11 @@ def _moe_chunk_sort_backward_wrapper(ctx, permuted_act_grad, permuted_probs_grad probs_grad_input = permuted_probs_grad if permuted_probs_grad.numel() > 0 else None act_grad, probs_grad = torch.ops.te_moe.chunk_sort_bwd( - permuted_act_grad, probs_grad_input, row_id_map, - ctx.num_tokens, ctx.hidden_size, + permuted_act_grad, + probs_grad_input, + row_id_map, + ctx.num_tokens, + ctx.hidden_size, ) if not ctx.needs_probs_grad or probs_grad.numel() == 0: @@ -1068,10 +1180,12 @@ def _moe_chunk_sort_backward_wrapper(ctx, permuted_act_grad, permuted_probs_grad ) # Register chunk sort ops as passthrough in QuantizedTensor.__torch_dispatch__ -_quantized_tensor_passthrough_ops.update({ - torch.ops.te_moe.chunk_sort_fwd.default, - torch.ops.te_moe.chunk_sort_bwd.default, -}) +_quantized_tensor_passthrough_ops.update( + { + torch.ops.te_moe.chunk_sort_fwd.default, + torch.ops.te_moe.chunk_sort_bwd.default, + } +) def moe_sort_chunks_by_index( @@ -1125,5 +1239,7 @@ def 
moe_sort_chunks_by_index_with_probs( """ if not inp.numel(): return inp, probs - output, permuted_probs, _ = torch.ops.te_moe.chunk_sort_fwd(inp, split_sizes, sorted_index, probs) + output, permuted_probs, _ = torch.ops.te_moe.chunk_sort_fwd( + inp, split_sizes, sorted_index, probs + ) return output, permuted_probs
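
Usage sketch (illustrative, not part of the patch series): with the autograd.Function implementations replaced by torch.library custom ops, the public permutation API can be wrapped in torch.compile(fullgraph=True) without graph breaks, which is what the new use_torch_compile test parametrization exercises. The snippet below mirrors the _maybe_compile test helper for the mask-map path; it assumes a CUDA device, torch >= 2.x, and TransformerEngine built with this patch applied, and the tensor shapes, dtypes, and routing are arbitrary example values rather than anything taken from the test suite.

    import torch
    from transformer_engine.pytorch.permutation import moe_permute, moe_unpermute

    def roundtrip(inp, routing_map, num_out_tokens):
        # Permute tokens to their experts, then restore the original token order.
        permuted, row_id_map = moe_permute(
            inp, routing_map, num_out_tokens=num_out_tokens, map_type="mask"
        )
        return moe_unpermute(permuted, row_id_map, restore_shape=inp.shape, map_type="mask")

    # fullgraph=True asserts that the te_moe custom ops introduce no graph breaks.
    compiled_roundtrip = torch.compile(roundtrip, fullgraph=True)

    num_tokens, num_experts, hidden_size = 16, 4, 32
    inp = torch.randn(
        num_tokens, hidden_size, device="cuda", dtype=torch.bfloat16, requires_grad=True
    )
    # topK = 1 routing: each token is sent to exactly one (random) expert.
    # int32 is used directly, matching the dtype recommended by moe_permute.
    routing_map = torch.zeros(num_tokens, num_experts, dtype=torch.int32, device="cuda")
    routing_map[torch.arange(num_tokens), torch.randint(0, num_experts, (num_tokens,))] = 1

    out = compiled_roundtrip(inp, routing_map, num_tokens)
    out.backward(torch.ones_like(out))

With topK = 1 and no merging probs, the round trip should reproduce the input ordering, and the backward pass runs through the registered te_moe backward ops rather than through traced Python autograd code.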