diff --git a/tests/pytorch/test_permutation.py b/tests/pytorch/test_permutation.py index be1ff30472..77796f203b 100644 --- a/tests/pytorch/test_permutation.py +++ b/tests/pytorch/test_permutation.py @@ -218,6 +218,17 @@ def backward_wrapper( return act.backward(backward_input, retain_graph=retain_graph) +def _maybe_compile(fn, use_torch_compile): + """Wrap fn with torch.compile(fullgraph=True) if requested.""" + if use_torch_compile: + torch._dynamo.reset() + import torch._functorch.config as functorch_config + + functorch_config.donated_buffer = False + return torch.compile(fn, fullgraph=True) + return fn + + def _test_permutation_index_map( te_dtype, num_tokens, @@ -227,6 +238,7 @@ def _test_permutation_index_map( num_out_tokens, with_probs, BENCHMARK=False, + use_torch_compile=False, ): if not with_probs and topK > 1: pytest.skip("Only permutations with topK=1 and without probabilities are supported.") @@ -298,9 +310,13 @@ def _test_permutation_index_map( te_permute_fwd_input.requires_grad_(True) te_permute_bwd_input = pytorch_permute_bwd_input.detach() - te_permute_output, row_id_map = te_permute( - te_permute_fwd_input, indices, num_out_tokens, map_type="index" + _permute = _maybe_compile( + lambda inp, idx, num_out, max_token: te_permute( + inp, idx, num_out, max_token, map_type="index" + ), + use_torch_compile, ) + te_permute_output, row_id_map = _permute(te_permute_fwd_input, indices, num_out_tokens, -1) te_permute_output.backward(te_permute_bwd_input, retain_graph=True) te_probs = None @@ -311,9 +327,11 @@ def _test_permutation_index_map( te_unpermute_fwd_input.requires_grad_(True) te_unpermute_bwd_input = pytorch_unpermute_bwd_input.detach() - te_unpermute_output = te_unpermute( - te_unpermute_fwd_input, row_id_map, te_probs, map_type="index" + _unpermute = _maybe_compile( + lambda inp, row_map, probs_val: te_unpermute(inp, row_map, probs_val, map_type="index"), + use_torch_compile, ) + te_unpermute_output = _unpermute(te_unpermute_fwd_input, row_id_map, te_probs) te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True) ################################################################################################################################### @@ -444,6 +462,7 @@ def _test_permutation_mask_map( num_out_tokens, with_probs, BENCHMARK=False, + use_torch_compile=False, ): if topK > num_expert: pytest.skip("topK should be smaller than the number of experts.") @@ -514,9 +533,11 @@ def _test_permutation_mask_map( te_permute_fwd_input.requires_grad_(True) te_permute_bwd_input = pytorch_permute_bwd_input.detach() - te_permute_output, row_id_map = te_permute( - te_permute_fwd_input, routing_map, num_out_tokens=num_out_tokens, map_type="mask" + _permute = _maybe_compile( + lambda inp, rmap, n_out: te_permute(inp, rmap, num_out_tokens=n_out, map_type="mask"), + use_torch_compile, ) + te_permute_output, row_id_map = _permute(te_permute_fwd_input, routing_map, num_out_tokens) te_permute_output.backward(te_permute_bwd_input, retain_graph=True) te_probs = None @@ -527,9 +548,11 @@ def _test_permutation_mask_map( te_unpermute_fwd_input.requires_grad_(True) te_unpermute_bwd_input = pytorch_unpermute_bwd_input.detach() - te_unpermute_output = te_unpermute( - te_unpermute_fwd_input, row_id_map, te_probs, restore_shape, map_type="mask" + _unpermute = _maybe_compile( + lambda inp, row_map, p, rs: te_unpermute(inp, row_map, p, rs, map_type="mask"), + use_torch_compile, ) + te_unpermute_output = _unpermute(te_unpermute_fwd_input, row_id_map, te_probs, restore_shape) 
te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True) ################################################################################################################################### @@ -666,6 +689,7 @@ def _test_permutation_and_padding_mask_map( with_merging_probs=False, align_size=16, BENCHMARK=False, + use_torch_compile=False, ): if topK > num_expert: pytest.skip("topK should be smaller than the number of experts.") @@ -957,6 +981,7 @@ def _test_permutation_and_padding_with_merging_probs( num_out_tokens, align_size=16, BENCHMARK=False, + use_torch_compile=False, ): """ Test the combination of merging_probs AND pad_offsets together in moe_unpermute. @@ -1291,6 +1316,7 @@ def _test_moe_chunk_sort( tp_size, hidden_size, BENCHMARK=False, + use_torch_compile=False, ): print( "chunk permute:" @@ -1340,7 +1366,11 @@ def _test_moe_chunk_sort( te_fwd_input.requires_grad_(True) te_bwd_input = pytorch_bwd_input.detach() - te_output = te_sort_chunks_by_index(te_fwd_input, split_sizes_cuda, sorted_idxs_cuda) + _sort = _maybe_compile( + lambda inp, ss, si: te_sort_chunks_by_index(inp, ss, si), + use_torch_compile, + ) + te_output = _sort(te_fwd_input, split_sizes_cuda, sorted_idxs_cuda) te_output.backward(te_bwd_input, retain_graph=True) ################################################################################################################################### @@ -1415,6 +1445,7 @@ def _test_permutation_mask_map_alongside_probs( num_out_tokens, tp_size, BENCHMARK=False, + use_torch_compile=False, ): if topK > num_expert: pytest.skip("topK should be smaller than the number of experts.") @@ -1510,30 +1541,27 @@ def _test_permutation_mask_map_alongside_probs( te_probs = probs.detach() te_probs.requires_grad_(True) - te_permute_output, te_permuted_probs, row_id_map = te_permute_with_probs( + def _alongside_probs_fn(fwd_inp, t_probs, rmap, ss1, si1, ss2, si2): + out, pprobs, rid = te_permute_with_probs( + fwd_inp, t_probs, rmap, num_out_tokens=num_out_tokens + ) + out, pprobs = te_sort_chunks_by_index_with_probs(out, pprobs, ss1, si1) + out_dtype = out.dtype + out = out * pprobs.unsqueeze(-1) + out = out.to(dtype=out_dtype) + out = te_sort_chunks_by_index(out, ss2, si2) + out = te_unpermute(out, rid, restore_shape=restore_shape, map_type="mask") + return out + + _fn = _maybe_compile(_alongside_probs_fn, use_torch_compile) + te_unpermute_output = _fn( te_permute_fwd_input, te_probs, routing_map, - num_out_tokens=num_out_tokens, - ) - - te_permute_output, te_permuted_probs = te_sort_chunks_by_index_with_probs( - te_permute_output, te_permuted_probs, split_sizes_cuda, sorted_idxs_cuda - ) - - te_permute_output_dtype = te_permute_output.dtype - te_permute_output = te_permute_output * te_permuted_probs.unsqueeze(-1) - te_permute_output = te_permute_output.to(dtype=te_permute_output_dtype) - - te_permute_output = te_sort_chunks_by_index( - te_permute_output, split_sizes_2_cuda, sorted_idxs_2_cuda - ) - - te_unpermute_output = te_unpermute( - te_permute_output, - row_id_map, - restore_shape=restore_shape, - map_type="mask", + split_sizes_cuda, + sorted_idxs_cuda, + split_sizes_2_cuda, + sorted_idxs_2_cuda, ) te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True) @@ -1647,6 +1675,11 @@ def perf_test_cuda_kernel(cuda_kernel_fn): @pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("topK", [2, 5]) @pytest.mark.parametrize("num_out_tokens", [None, 2039]) +@pytest.mark.parametrize("use_torch_compile", [False, True]) +@pytest.mark.skipif( + 
torch.__version__ < "2", + reason="torch.compile not available - skipping torch.compile tests", +) def test_permutation_index_map( te_dtype, num_tokens, @@ -1654,7 +1687,12 @@ def test_permutation_index_map( hidden_size, topK, num_out_tokens, + use_torch_compile, ): + if use_torch_compile and torch.__version__ < "2": + pytest.skip("torch.compile not available") + if use_torch_compile and (num_expert != 7 or topK != 2): + pytest.skip("torch.compile tested with single config only") with_probs = True BENCHMARK = False @@ -1667,6 +1705,7 @@ def test_permutation_index_map( num_out_tokens=num_out_tokens, with_probs=with_probs, BENCHMARK=BENCHMARK, + use_torch_compile=use_torch_compile, ) @@ -1676,6 +1715,7 @@ def test_permutation_index_map( @pytest.mark.parametrize("hidden_size", [4096]) @pytest.mark.parametrize("topK", [2, 5]) @pytest.mark.parametrize("num_out_tokens", [None, 2039]) +@pytest.mark.parametrize("use_torch_compile", [False, True]) def test_permutation_mask_map( te_dtype, num_tokens, @@ -1683,7 +1723,10 @@ def test_permutation_mask_map( hidden_size, topK, num_out_tokens, + use_torch_compile, ): + if use_torch_compile and (num_expert != 7 or topK != 2): + pytest.skip("torch.compile tested with single config only") with_probs = True BENCHMARK = False @@ -1696,6 +1739,7 @@ def test_permutation_mask_map( num_out_tokens=num_out_tokens, with_probs=with_probs, BENCHMARK=BENCHMARK, + use_torch_compile=use_torch_compile, ) @@ -1711,6 +1755,7 @@ def test_permutation_mask_map( ], ) @pytest.mark.parametrize("with_merging_probs", [True, False]) +@pytest.mark.parametrize("use_torch_compile", [False, True]) def test_permutation_and_padding_mask_map( te_dtype, num_tokens, @@ -1719,7 +1764,10 @@ def test_permutation_and_padding_mask_map( topK, num_out_tokens, with_merging_probs, + use_torch_compile, ): + if use_torch_compile and (num_expert != 8 or topK != 2): + pytest.skip("torch.compile tested with single config only") BENCHMARK = False _test_permutation_and_padding_mask_map( @@ -1731,6 +1779,7 @@ def test_permutation_and_padding_mask_map( num_out_tokens=num_out_tokens, with_merging_probs=with_merging_probs, BENCHMARK=BENCHMARK, + use_torch_compile=use_torch_compile, ) @@ -1745,6 +1794,7 @@ def test_permutation_and_padding_mask_map( (4096, 512, 9216, 8), ], ) +@pytest.mark.parametrize("use_torch_compile", [False, True]) def test_permutation_and_padding_with_merging_probs( te_dtype, num_tokens, @@ -1752,8 +1802,11 @@ def test_permutation_and_padding_with_merging_probs( hidden_size, topK, num_out_tokens, + use_torch_compile, ): """Test moe_unpermute backward pass with BOTH merging_probs AND pad_offsets.""" + if use_torch_compile and (num_expert != 8 or topK != 2): + pytest.skip("torch.compile tested with single config only") BENCHMARK = False _test_permutation_and_padding_with_merging_probs( @@ -1764,11 +1817,13 @@ def test_permutation_and_padding_with_merging_probs( topK=topK, num_out_tokens=num_out_tokens, BENCHMARK=BENCHMARK, + use_torch_compile=use_torch_compile, ) @pytest.mark.parametrize("te_dtype", _te_dtypes) -def test_permutation_mask_map_empty_input(te_dtype): +@pytest.mark.parametrize("use_torch_compile", [False, True]) +def test_permutation_mask_map_empty_input(te_dtype, use_torch_compile): with_probs = True BENCHMARK = False @@ -1781,6 +1836,7 @@ def test_permutation_mask_map_empty_input(te_dtype): num_out_tokens=0, with_probs=with_probs, BENCHMARK=BENCHMARK, + use_torch_compile=use_torch_compile, ) @@ -1791,6 +1847,7 @@ def test_permutation_mask_map_empty_input(te_dtype): 
@pytest.mark.parametrize("topK", [2, 5]) @pytest.mark.parametrize("num_out_tokens", [None, 2039]) @pytest.mark.parametrize("tp_size", [1, 2]) +@pytest.mark.parametrize("use_torch_compile", [False, True]) def test_permutation_mask_map_alongside_probs( te_dtype, num_tokens, @@ -1799,7 +1856,10 @@ def test_permutation_mask_map_alongside_probs( topK, num_out_tokens, tp_size, + use_torch_compile, ): + if use_torch_compile and (num_expert != 7 or topK != 2 or tp_size != 1): + pytest.skip("torch.compile tested with single config only") _test_permutation_mask_map_alongside_probs( te_dtype=te_dtype, num_tokens=num_tokens, @@ -1808,11 +1868,13 @@ def test_permutation_mask_map_alongside_probs( topK=topK, num_out_tokens=num_out_tokens, tp_size=tp_size, + use_torch_compile=use_torch_compile, ) @pytest.mark.parametrize("te_dtype", _te_dtypes) -def test_permutation_mask_map_alongside_probs_empty_input(te_dtype): +@pytest.mark.parametrize("use_torch_compile", [False, True]) +def test_permutation_mask_map_alongside_probs_empty_input(te_dtype, use_torch_compile): _test_permutation_mask_map_alongside_probs( te_dtype=te_dtype, num_tokens=0, @@ -1821,6 +1883,7 @@ def test_permutation_mask_map_alongside_probs_empty_input(te_dtype): topK=2, num_out_tokens=0, tp_size=2, + use_torch_compile=use_torch_compile, ) @@ -1875,12 +1938,22 @@ def test_permutation_mask_map_fp8( @pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("use_torch_compile", [False, True]) +@pytest.mark.skipif( + torch.__version__ < "2", + reason="torch.compile not available - skipping torch.compile tests", +) def test_permutation_index_map_topk1_no_probs( te_dtype, num_tokens, num_expert, hidden_size, + use_torch_compile, ): + if use_torch_compile and torch.__version__ < "2": + pytest.skip("torch.compile not available") + if use_torch_compile and num_expert != 7: + pytest.skip("torch.compile tested with single config only") topK = 1 num_out_tokens = None with_probs = False @@ -1895,6 +1968,7 @@ def test_permutation_index_map_topk1_no_probs( num_out_tokens=num_out_tokens, with_probs=with_probs, BENCHMARK=BENCHMARK, + use_torch_compile=use_torch_compile, ) @@ -1902,12 +1976,16 @@ def test_permutation_index_map_topk1_no_probs( @pytest.mark.parametrize("num_tokens", [4096]) @pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("use_torch_compile", [False, True]) def test_permutation_mask_map_topk1_no_probs( te_dtype, num_tokens, num_expert, hidden_size, + use_torch_compile, ): + if use_torch_compile and num_expert != 7: + pytest.skip("torch.compile tested with single config only") topK = 1 num_out_tokens = None with_probs = False @@ -1922,6 +2000,7 @@ def test_permutation_mask_map_topk1_no_probs( num_out_tokens=num_out_tokens, with_probs=with_probs, BENCHMARK=BENCHMARK, + use_torch_compile=use_torch_compile, ) @@ -1930,13 +2009,17 @@ def test_permutation_mask_map_topk1_no_probs( @pytest.mark.parametrize("num_expert", [7, 16]) @pytest.mark.parametrize("tp_size", [2, 8]) @pytest.mark.parametrize("hidden_size", [4096]) +@pytest.mark.parametrize("use_torch_compile", [False, True]) def test_chunk_permutation( te_dtype, num_tokens, num_expert, tp_size, hidden_size, + use_torch_compile, ): + if use_torch_compile and (num_expert != 7 or tp_size != 2): + pytest.skip("torch.compile tested with single config only") BENCHMARK = False _test_moe_chunk_sort( @@ -1946,11 +2029,13 @@ 
def test_chunk_permutation( tp_size=tp_size, hidden_size=hidden_size, BENCHMARK=BENCHMARK, + use_torch_compile=use_torch_compile, ) @pytest.mark.parametrize("te_dtype", _te_dtypes) -def test_chunk_permutation_empty_input(te_dtype): +@pytest.mark.parametrize("use_torch_compile", [False, True]) +def test_chunk_permutation_empty_input(te_dtype, use_torch_compile): BENCHMARK = False _test_moe_chunk_sort( @@ -1960,6 +2045,7 @@ def test_chunk_permutation_empty_input(te_dtype): tp_size=2, hidden_size=4096, BENCHMARK=BENCHMARK, + use_torch_compile=use_torch_compile, ) diff --git a/transformer_engine/pytorch/permutation.py b/transformer_engine/pytorch/permutation.py index 5beeed1262..5b6498e70e 100644 --- a/transformer_engine/pytorch/permutation.py +++ b/transformer_engine/pytorch/permutation.py @@ -7,10 +7,16 @@ from typing import Optional, Tuple import torch +# Allow warnings.warn inside torch.compile without graph breaks +torch._dynamo.config.reorderable_logging_functions.add(warnings.warn) + import transformer_engine_torch as tex import transformer_engine.pytorch.triton.permutation as triton_permutation from transformer_engine.pytorch.constants import TE_DType -from transformer_engine.pytorch.quantized_tensor import QuantizedTensor +from transformer_engine.pytorch.quantized_tensor import ( + QuantizedTensor, + _quantized_tensor_passthrough_ops, +) from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockwiseQTensor from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor @@ -22,495 +28,723 @@ ] -class _moe_permute_index_map(torch.autograd.Function): - """functional Permute with index router map""" +# ===================== _moe_permute_index_map custom ops ===================== - workspace = None - max_expanded_token_num = 0 +# Workspace state for moe_permute_index_map +_moe_permute_index_map_workspace = None +_moe_permute_index_map_max_expanded_token_num = 0 - @staticmethod - def forward( - ctx, - inp: torch.Tensor, - index: torch.Tensor, - num_out_tokens: int, - max_token_num: int, - ) -> Tuple[torch.Tensor, torch.Tensor]: - # pylint: disable=missing-function-docstring - # Empty input check - if not inp.numel(): - return inp, torch.tensor([], device=inp.device) - # Device check - assert inp.is_cuda, "TransformerEngine needs CUDA." - assert index.is_cuda, "TransformerEngine needs CUDA." - # Shape check - assert inp.size(0) == index.size(0), "Permute not possible" +@torch.library.custom_op("te_moe::permute_index_map", mutates_args=[]) +def moe_permute_index_map_forward( + inp: torch.Tensor, + index: torch.Tensor, + num_out_tokens: int, + max_token_num: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Forward pass for MoE permute with index router map.""" + global _moe_permute_index_map_workspace, _moe_permute_index_map_max_expanded_token_num - # Data type check - dtype = TE_DType[inp.dtype] - if index.dtype != torch.int32: - warnings.warn( - f"The data type of the input `index` of Permute is {index.dtype}! " - "The recommended type is torch.int32." 
- ) - index = index.to(torch.int32) + dtype = TE_DType[inp.dtype] - topK = index.size(1) + topK = index.size(1) - input_max_expanded_token_num = max(max_token_num, inp.size(0)) * topK - if _moe_permute_index_map.max_expanded_token_num < input_max_expanded_token_num: - _moe_permute_index_map.max_expanded_token_num = input_max_expanded_token_num - _moe_permute_index_map.workspace = [] + input_max_expanded_token_num = max(max_token_num, inp.size(0)) * topK + if _moe_permute_index_map_max_expanded_token_num < input_max_expanded_token_num: + _moe_permute_index_map_max_expanded_token_num = input_max_expanded_token_num + _moe_permute_index_map_workspace = [] - permuted_act, row_id_map, _moe_permute_index_map.workspace = tex.moe_permute_fwd( - inp, - dtype, - index, - num_out_tokens, - _moe_permute_index_map.workspace, - _moe_permute_index_map.max_expanded_token_num, - ) + permuted_act, row_id_map, _moe_permute_index_map_workspace = tex.moe_permute_fwd( + inp, + dtype, + index, + num_out_tokens, + _moe_permute_index_map_workspace, + _moe_permute_index_map_max_expanded_token_num, + ) - ctx.row_id_map = row_id_map - ctx.num_tokens = index.size(0) - ctx.topK = index.size(1) - return permuted_act, row_id_map - - @staticmethod - def backward( - ctx, - permuted_act_grad: torch.Tensor, - _, - ) -> Tuple[torch.Tensor, ...]: - # pylint: disable=missing-function-docstring - # Empty input check - if not permuted_act_grad.numel(): - return permuted_act_grad, None, None, None + return permuted_act, row_id_map - if not permuted_act_grad.is_contiguous(): - permuted_act_grad = permuted_act_grad.contiguous() - dtype = TE_DType[permuted_act_grad.dtype] - act_grad = None - if ctx.needs_input_grad[0]: - act_grad = tex.moe_permute_bwd( - permuted_act_grad, dtype, ctx.row_id_map, torch.empty(0), ctx.num_tokens, ctx.topK - ) +@moe_permute_index_map_forward.register_fake +def _moe_permute_index_map_fake( + inp: torch.Tensor, + index: torch.Tensor, + num_out_tokens: int, + max_token_num: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Fake implementation for shape inference.""" + num_tokens = inp.shape[0] + topK = index.shape[1] - return act_grad, None, None, None + # Infer output shape + output_tokens = num_out_tokens if num_out_tokens > 0 else num_tokens * topK + # row_id_map is 1D with size = num_tokens * topK + fake_output = torch.empty((output_tokens, inp.shape[1]), dtype=inp.dtype, device=inp.device) + fake_row_id_map = torch.empty((num_tokens * topK,), dtype=torch.int32, device=inp.device) -class _moe_unpermute_index_map(torch.autograd.Function): - """functional Unpermute with index router map""" + return fake_output, fake_row_id_map - @staticmethod - def forward( - ctx, - inp: torch.Tensor, - row_id_map: torch.Tensor, - probs: torch.Tensor, - ) -> torch.Tensor: - # pylint: disable=missing-function-docstring - # Empty input check - if not inp.numel(): - ctx.probs = probs - return inp - # None probs check - if probs is not None: - assert probs.is_cuda, "TransformerEngine needs CUDA." 
+@torch.library.custom_op("te_moe::permute_index_map_bwd", mutates_args=[]) +def moe_permute_index_map_backward( + grad_permuted_act: torch.Tensor, + row_id_map: torch.Tensor, + num_tokens: int, + topK: int, +) -> torch.Tensor: + """Backward pass for MoE permute with index router map.""" + dtype = TE_DType[grad_permuted_act.dtype] + act_grad = tex.moe_permute_bwd( + grad_permuted_act, dtype, row_id_map, torch.empty(0), num_tokens, topK + ) + return act_grad - if probs.dtype != torch.float32: - warnings.warn( - f"The data type of the input `probs` of Unpermute is {probs.dtype}! " - "The recommended type is torch.float32." - ) - probs = probs.to(torch.float32) - num_tokens = probs.size(0) - topK = probs.size(1) - else: - num_tokens = row_id_map.size(0) - topK = 1 - probs = torch.empty(0) +@moe_permute_index_map_backward.register_fake +def _moe_permute_index_map_backward_fake( + grad_permuted_act: torch.Tensor, + row_id_map: torch.Tensor, + num_tokens: int, + topK: int, +) -> torch.Tensor: + """Fake implementation for shape inference of backward.""" + return torch.empty( + (num_tokens, grad_permuted_act.shape[1]), + dtype=grad_permuted_act.dtype, + device=grad_permuted_act.device, + ) - # Device check - assert inp.is_cuda, "TransformerEngine needs CUDA." - assert row_id_map.is_cuda, "TransformerEngine needs CUDA." - # Data type check - dtype = TE_DType[inp.dtype] - if row_id_map.dtype != torch.int32: - warnings.warn( - f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! " - "The recommended type is torch.int32." - ) - row_id_map = row_id_map.to(torch.int32) +def _moe_permute_index_map_setup_context(ctx, inputs, output): + """Save context for backward pass.""" + inp, index, num_out_tokens, max_token_num = inputs + permuted_act, row_id_map = output + ctx.save_for_backward(row_id_map) + ctx.num_tokens = index.size(0) + ctx.topK = index.size(1) - unpermuted_output = tex.moe_unpermute_fwd(inp, dtype, row_id_map, probs, num_tokens, topK) - ctx.save_for_backward(inp, row_id_map, probs) - return unpermuted_output +def _moe_permute_index_map_backward_wrapper(ctx, grad_permuted_act, grad_row_id_map): + """Backward pass wrapper that calls the custom backward op.""" + if not grad_permuted_act.is_contiguous(): + grad_permuted_act = grad_permuted_act.contiguous() + + (row_id_map,) = ctx.saved_tensors + act_grad = torch.ops.te_moe.permute_index_map_bwd( + grad_permuted_act, row_id_map, ctx.num_tokens, ctx.topK + ) + + return act_grad, None, None, None + + +moe_permute_index_map_forward.register_autograd( + _moe_permute_index_map_backward_wrapper, + setup_context=_moe_permute_index_map_setup_context, +) + + +# ===================== _moe_unpermute_index_map custom ops ===================== + + +@torch.library.custom_op("te_moe::unpermute_index_map_fwd", mutates_args=[]) +def moe_unpermute_index_map_forward( + inp: torch.Tensor, + row_id_map: torch.Tensor, + probs: torch.Tensor, + num_tokens: int, + topK: int, +) -> torch.Tensor: + """Forward pass for MoE unpermute with index router map.""" + dtype = TE_DType[inp.dtype] + return tex.moe_unpermute_fwd(inp, dtype, row_id_map, probs, num_tokens, topK) + + +@moe_unpermute_index_map_forward.register_fake +def _moe_unpermute_index_map_forward_fake( + inp: torch.Tensor, + row_id_map: torch.Tensor, + probs: torch.Tensor, + num_tokens: int, + topK: int, +) -> torch.Tensor: + """Fake implementation for shape inference.""" + # Output shape: (num_tokens, hidden_size) + return torch.empty((num_tokens, inp.shape[1]), dtype=inp.dtype, 
device=inp.device) - @staticmethod - def backward( - ctx, - unpermuted_act_grad: torch.Tensor, - ) -> Tuple[torch.Tensor, None, torch.Tensor]: - # pylint: disable=missing-function-docstring - # Empty input check - if not unpermuted_act_grad.numel(): - return unpermuted_act_grad, None, ctx.probs - if not unpermuted_act_grad.is_contiguous(): - unpermuted_act_grad = unpermuted_act_grad.contiguous() +@torch.library.custom_op("te_moe::unpermute_index_map_bwd", mutates_args=[]) +def moe_unpermute_index_map_backward( + unpermuted_act_grad: torch.Tensor, + fwd_input: torch.Tensor, + row_id_map: torch.Tensor, + probs: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Backward pass for MoE unpermute with index router map.""" + dtype = TE_DType[unpermuted_act_grad.dtype] + act_grad, prob_grad = tex.moe_unpermute_bwd( + unpermuted_act_grad, fwd_input, dtype, row_id_map, probs + ) + return act_grad, prob_grad + + +@moe_unpermute_index_map_backward.register_fake +def _moe_unpermute_index_map_backward_fake( + unpermuted_act_grad: torch.Tensor, + fwd_input: torch.Tensor, + row_id_map: torch.Tensor, + probs: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Fake implementation for shape inference of backward.""" + # act_grad shape: (fwd_input.size(0), hidden_size) + # prob_grad shape: (num_tokens, topK) + topK = probs.size(1) if probs.numel() > 0 else 1 + num_tokens = probs.size(0) if probs.numel() > 0 else row_id_map.size(0) + act_grad = torch.empty( + (fwd_input.size(0), unpermuted_act_grad.shape[1]), + dtype=unpermuted_act_grad.dtype, + device=unpermuted_act_grad.device, + ) + prob_grad = torch.empty( + (num_tokens, topK), dtype=torch.float32, device=unpermuted_act_grad.device + ) + return act_grad, prob_grad + + +def _moe_unpermute_index_map_setup_context(ctx, inputs, output): + """Save context for backward pass.""" + inp, row_id_map, probs, num_tokens, topK = inputs + ctx.save_for_backward(inp, row_id_map, probs) + ctx.needs_probs_grad = probs.requires_grad + + +def _moe_unpermute_index_map_backward_wrapper(ctx, unpermuted_act_grad): + """Backward pass wrapper that calls the custom backward op.""" + if not unpermuted_act_grad.is_contiguous(): + unpermuted_act_grad = unpermuted_act_grad.contiguous() - dtype = TE_DType[unpermuted_act_grad.dtype] - inp, row_id_map, probs = ctx.saved_tensors + inp, row_id_map, probs = ctx.saved_tensors - act_grad = None + act_grad, prob_grad = torch.ops.te_moe.unpermute_index_map_bwd( + unpermuted_act_grad, inp, row_id_map, probs + ) + + if not ctx.needs_probs_grad: prob_grad = None - if ctx.needs_input_grad[0]: - act_grad, prob_grad = tex.moe_unpermute_bwd( - unpermuted_act_grad, inp, dtype, row_id_map, probs - ) - if not ctx.needs_input_grad[2]: - prob_grad = None - return act_grad, None, prob_grad + return act_grad, None, prob_grad, None, None -class _moe_permute_mask_map(torch.autograd.Function): - """functional Permute with mask router map""" +moe_unpermute_index_map_forward.register_autograd( + _moe_unpermute_index_map_backward_wrapper, + setup_context=_moe_unpermute_index_map_setup_context, +) - @staticmethod - def forward( - ctx, - inp: torch.Tensor, - routing_map: torch.Tensor, - num_out_tokens: int, - probs: torch.Tensor, - pad_offsets: Optional[torch.Tensor], - ) -> Tuple[torch.Tensor, torch.Tensor]: - # pylint: disable=missing-function-docstring - if not inp.numel(): - ctx.probs = probs - return inp, torch.tensor([], device=inp.device), torch.tensor([], device=inp.device) - assert inp.is_cuda, "TransformerEngine needs CUDA." 
- assert routing_map.is_cuda, "TransformerEngine needs CUDA." - if probs is not None: - assert probs.is_cuda, "TransformerEngine needs CUDA." - if pad_offsets is not None: - assert pad_offsets.is_cuda, "TransformerEngine needs CUDA." +# ===================== _moe_permute_mask_map custom ops ===================== - assert inp.size(0) == routing_map.size(0), "Permute not possible" - num_tokens, hidden_size = inp.size() - num_experts = routing_map.size(1) - assert ( - num_out_tokens is not None - ), "num_out_tokens must be provided to the fused permute function." - - row_id_map = triton_permutation.make_row_id_map(routing_map, num_tokens, num_experts) - - fp8 = isinstance(inp, QuantizedTensor) - per_tensor_recipe = isinstance(inp, Float8Tensor) - blockwise_recipe = isinstance(inp, Float8BlockwiseQTensor) - mxfp8_recipe = isinstance(inp, MXFP8Tensor) - - if fp8: - fp8_dtype = inp._fp8_dtype - fake_dtype = inp.dtype - # blockwise scaling - if blockwise_recipe: - fp8_scale = inp._rowwise_scale_inv.T.contiguous() - scale_hidden_dim = fp8_scale.shape[1] - assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" - inp = inp._rowwise_data - # mxfp8 scaling - elif mxfp8_recipe: - fp8_scale = inp._rowwise_scale_inv.contiguous() - scale_hidden_dim = fp8_scale.shape[1] - assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" - inp = inp._rowwise_data - # per-tensor scaling - elif per_tensor_recipe: - # Kernel does not need scale in per-tensor scaling - fp8_scale = None - scale_hidden_dim = None - fp8_scale_inv = inp._scale_inv - inp = inp._data - else: - raise ValueError("Unsupported FP8 recipe") - else: + +@torch.library.custom_op("te_moe::permute_mask_map_fwd", mutates_args=[]) +def moe_permute_mask_map_forward( + inp: torch.Tensor, + routing_map: torch.Tensor, + num_out_tokens: int, + probs: Optional[torch.Tensor], + pad_offsets: Optional[torch.Tensor], +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Forward pass for MoE permute with mask router map.""" + num_tokens, hidden_size = inp.size() + num_experts = routing_map.size(1) + + row_id_map = triton_permutation.make_row_id_map(routing_map, num_tokens, num_experts) + + # FP8 handling + fp8 = isinstance(inp, QuantizedTensor) + per_tensor_recipe = isinstance(inp, Float8Tensor) + blockwise_recipe = isinstance(inp, Float8BlockwiseQTensor) + mxfp8_recipe = isinstance(inp, MXFP8Tensor) + + if fp8: + fp8_dtype = inp._fp8_dtype + fake_dtype = inp.dtype + if blockwise_recipe: + fp8_scale = inp._rowwise_scale_inv.T.contiguous() + scale_hidden_dim = fp8_scale.shape[1] + assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" + inp = inp._rowwise_data + elif mxfp8_recipe: + fp8_scale = inp._rowwise_scale_inv.contiguous() + scale_hidden_dim = fp8_scale.shape[1] + assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" + inp = inp._rowwise_data + elif per_tensor_recipe: fp8_scale = None - fp8_dtype = None scale_hidden_dim = None + fp8_scale_inv = inp._scale_inv + inp = inp._data + else: + raise ValueError("Unsupported FP8 recipe") + else: + fp8_scale = None + fp8_dtype = None + scale_hidden_dim = None + + output, permuted_scale, permuted_probs = triton_permutation.permute_with_mask_map( + inp, + row_id_map, + probs, + fp8_scale, + pad_offsets, + num_tokens, + num_experts, + num_out_tokens, + hidden_size, + scale_hidden_dim, + ) - output, permuted_scale, permuted_probs = triton_permutation.permute_with_mask_map( - inp, - row_id_map, - probs, - fp8_scale, - pad_offsets, - num_tokens, 
- num_experts, - num_out_tokens, - hidden_size, - scale_hidden_dim, - ) + if fp8: + if per_tensor_recipe: + output = Float8Tensor( + data=output, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + shape=output.shape, + dtype=fake_dtype, + ) + elif blockwise_recipe: + output = Float8BlockwiseQTensor( + shape=output.shape, + dtype=fake_dtype, + rowwise_data=output, + rowwise_scale_inv=permuted_scale.T.contiguous(), + columnwise_data=None, + columnwise_scale_inv=None, + fp8_dtype=fp8_dtype, + quantizer=None, + is_2D_scaled=False, + requires_grad=output.requires_grad, + ) + elif mxfp8_recipe: + output = MXFP8Tensor( + shape=output.shape, + dtype=fake_dtype, + fp8_dtype=fp8_dtype, + rowwise_data=output, + rowwise_scale_inv=permuted_scale.contiguous(), + columnwise_data=None, + columnwise_scale_inv=None, + quantizer=None, + requires_grad=output.requires_grad, + with_gemm_swizzled_scales=False, + ) - if fp8: - if per_tensor_recipe: - output = Float8Tensor( - data=output, - fp8_dtype=fp8_dtype, - fp8_scale_inv=fp8_scale_inv, - shape=output.shape, - dtype=fake_dtype, - ) - elif blockwise_recipe: - output = Float8BlockwiseQTensor( - shape=output.shape, - dtype=fake_dtype, - rowwise_data=output, - rowwise_scale_inv=permuted_scale.T.contiguous(), - columnwise_data=None, - columnwise_scale_inv=None, - fp8_dtype=fp8_dtype, - quantizer=None, - is_2D_scaled=False, - requires_grad=output.requires_grad, - ) - elif mxfp8_recipe: - output = MXFP8Tensor( - shape=output.shape, - dtype=fake_dtype, - fp8_dtype=fp8_dtype, - rowwise_data=output, - rowwise_scale_inv=permuted_scale.contiguous(), - columnwise_data=None, - columnwise_scale_inv=None, - quantizer=None, - requires_grad=output.requires_grad, - with_gemm_swizzled_scales=False, - ) + # If permuted_probs is None, return empty tensor (custom ops need concrete tensors) + if permuted_probs is None: + permuted_probs = torch.empty(0, device=inp.device) - ctx.save_for_backward(row_id_map, pad_offsets) - ctx.num_experts = num_experts - ctx.num_tokens = num_tokens - ctx.hidden_size = hidden_size - return output, row_id_map, permuted_probs - - @staticmethod - def backward( - ctx, - permuted_act_grad: torch.Tensor, - _, - permuted_probs_grad: torch.Tensor, - ) -> Tuple[torch.Tensor, ...]: - # pylint: disable=missing-function-docstring - if not permuted_act_grad.numel(): - return permuted_act_grad, None, None, ctx.probs, None - - act_grad = None + return output, row_id_map, permuted_probs + + +@moe_permute_mask_map_forward.register_fake +def _moe_permute_mask_map_forward_fake( + inp: torch.Tensor, + routing_map: torch.Tensor, + num_out_tokens: int, + probs: Optional[torch.Tensor], + pad_offsets: Optional[torch.Tensor], +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Fake implementation for shape inference.""" + num_tokens = inp.shape[0] + hidden_size = inp.shape[1] + num_experts = routing_map.shape[1] + # row_id_map: (num_tokens, num_experts * 2 + 1) + fake_output = torch.empty((num_out_tokens, hidden_size), dtype=inp.dtype, device=inp.device) + fake_row_id_map = torch.empty( + (num_tokens, num_experts * 2 + 1), dtype=torch.int32, device=inp.device + ) + if probs is not None: + fake_permuted_probs = torch.empty((num_out_tokens,), dtype=probs.dtype, device=inp.device) + else: + fake_permuted_probs = torch.empty(0, device=inp.device) + return fake_output, fake_row_id_map, fake_permuted_probs + + +@torch.library.custom_op("te_moe::permute_mask_map_bwd", mutates_args=[]) +def moe_permute_mask_map_backward( + permuted_act_grad: torch.Tensor, + 
permuted_probs_grad: Optional[torch.Tensor], + row_id_map: torch.Tensor, + pad_offsets: Optional[torch.Tensor], + num_tokens: int, + num_experts: int, + hidden_size: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Backward pass for MoE permute with mask router map.""" + act_grad, probs_grad = triton_permutation.unpermute_with_mask_map( + permuted_act_grad, + row_id_map, + None, + permuted_probs_grad, + pad_offsets, + num_tokens, + num_experts, + hidden_size, + ) + if probs_grad is None: + probs_grad = torch.empty(0, device=permuted_act_grad.device) + return act_grad, probs_grad + + +@moe_permute_mask_map_backward.register_fake +def _moe_permute_mask_map_backward_fake( + permuted_act_grad: torch.Tensor, + permuted_probs_grad: Optional[torch.Tensor], + row_id_map: torch.Tensor, + pad_offsets: Optional[torch.Tensor], + num_tokens: int, + num_experts: int, + hidden_size: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Fake for backward shape inference.""" + act_grad = torch.empty( + (num_tokens, hidden_size), dtype=permuted_act_grad.dtype, device=permuted_act_grad.device + ) + if permuted_probs_grad is not None: + probs_grad = torch.empty( + (num_tokens, num_experts), + dtype=permuted_probs_grad.dtype, + device=permuted_act_grad.device, + ) + else: + probs_grad = torch.empty(0, device=permuted_act_grad.device) + return act_grad, probs_grad + + +def _moe_permute_mask_map_setup_context(ctx, inputs, output): + """Save context for backward pass.""" + inp, routing_map, num_out_tokens, probs, pad_offsets = inputs + output_tensor, row_id_map, permuted_probs = output + ctx.save_for_backward(row_id_map, pad_offsets) + ctx.num_experts = routing_map.size(1) + ctx.num_tokens = inp.size(0) + ctx.hidden_size = inp.size(1) + ctx.needs_probs_grad = probs is not None and probs.requires_grad + + +def _moe_permute_mask_map_backward_wrapper(ctx, grad_output, grad_row_id_map, grad_permuted_probs): + """Backward wrapper calling the custom backward op.""" + assert not isinstance( + grad_output, QuantizedTensor + ), "The backward of moe_permute does not support FP8." + + row_id_map, pad_offsets = ctx.saved_tensors + + # Pass permuted_probs_grad only if it has content + probs_grad_input = grad_permuted_probs if grad_permuted_probs.numel() > 0 else None + + act_grad, probs_grad = torch.ops.te_moe.permute_mask_map_bwd( + grad_output, + probs_grad_input, + row_id_map, + pad_offsets, + ctx.num_tokens, + ctx.num_experts, + ctx.hidden_size, + ) + + if not ctx.needs_probs_grad or probs_grad.numel() == 0: probs_grad = None - if ctx.needs_input_grad[0]: - row_id_map, pad_offsets = ctx.saved_tensors - assert not isinstance( - permuted_act_grad, QuantizedTensor - ), "The backward of moe_permute does not support FP8." 
- act_grad, probs_grad = triton_permutation.unpermute_with_mask_map( - permuted_act_grad, - row_id_map, - None, - permuted_probs_grad, - pad_offsets, - ctx.num_tokens, - ctx.num_experts, - ctx.hidden_size, + + return act_grad, None, None, probs_grad, None + + +moe_permute_mask_map_forward.register_autograd( + _moe_permute_mask_map_backward_wrapper, + setup_context=_moe_permute_mask_map_setup_context, +) + + +# ===================== _moe_unpermute_mask_map custom ops ===================== + + +@torch.library.custom_op("te_moe::unpermute_mask_map_fwd", mutates_args=[]) +def moe_unpermute_mask_map_forward( + inp: torch.Tensor, + row_id_map: torch.Tensor, + merging_probs: Optional[torch.Tensor], + num_tokens: int, + num_experts: int, + hidden_size: int, + pad_offsets: Optional[torch.Tensor], +) -> torch.Tensor: + """Forward pass for MoE unpermute with mask router map.""" + assert not isinstance( + inp, QuantizedTensor + ), "The forward of moe_unpermute does not support FP8." + unpermuted_output, _ = triton_permutation.unpermute_with_mask_map( + inp, + row_id_map, + merging_probs, + None, + pad_offsets, + num_tokens, + num_experts, + hidden_size, + ) + return unpermuted_output + + +@moe_unpermute_mask_map_forward.register_fake +def _moe_unpermute_mask_map_forward_fake( + inp: torch.Tensor, + row_id_map: torch.Tensor, + merging_probs: Optional[torch.Tensor], + num_tokens: int, + num_experts: int, + hidden_size: int, + pad_offsets: Optional[torch.Tensor], +) -> torch.Tensor: + """Fake implementation for shape inference.""" + return torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device=inp.device) + + +@torch.library.custom_op("te_moe::unpermute_mask_map_bwd_with_probs", mutates_args=[]) +def moe_unpermute_mask_map_backward_with_probs( + unpermuted_act_grad: torch.Tensor, + row_id_map: torch.Tensor, + fwd_input: torch.Tensor, + merging_probs: torch.Tensor, + pad_offsets: Optional[torch.Tensor], + num_tokens: int, + num_experts: int, + num_permuted_tokens: int, + hidden_size: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Backward pass for MoE unpermute with merging probs.""" + act_grad, probs_grad = triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs( + unpermuted_act_grad, + row_id_map, + fwd_input, + merging_probs, + pad_offsets, + num_tokens, + num_experts, + num_permuted_tokens, + hidden_size, + ) + return act_grad, probs_grad + + +@moe_unpermute_mask_map_backward_with_probs.register_fake +def _moe_unpermute_mask_map_bwd_with_probs_fake( + unpermuted_act_grad: torch.Tensor, + row_id_map: torch.Tensor, + fwd_input: torch.Tensor, + merging_probs: torch.Tensor, + pad_offsets: Optional[torch.Tensor], + num_tokens: int, + num_experts: int, + num_permuted_tokens: int, + hidden_size: int, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Fake for backward shape inference with merging probs.""" + act_grad = torch.empty( + (num_permuted_tokens, hidden_size), + dtype=unpermuted_act_grad.dtype, + device=unpermuted_act_grad.device, + ) + probs_grad = torch.empty( + (num_tokens, num_experts), + dtype=merging_probs.dtype, + device=unpermuted_act_grad.device, + ) + return act_grad, probs_grad + + +@torch.library.custom_op("te_moe::unpermute_mask_map_bwd_no_probs", mutates_args=[]) +def moe_unpermute_mask_map_backward_no_probs( + unpermuted_act_grad: torch.Tensor, + row_id_map: torch.Tensor, + pad_offsets: Optional[torch.Tensor], + num_tokens: int, + num_experts: int, + num_permuted_tokens: int, + hidden_size: int, +) -> torch.Tensor: + """Backward pass for MoE unpermute without merging 
probs (permute grad back).""" + # FP8 handling + fp8 = isinstance(unpermuted_act_grad, QuantizedTensor) + per_tensor_recipe = isinstance(unpermuted_act_grad, Float8Tensor) + blockwise_recipe = isinstance(unpermuted_act_grad, Float8BlockwiseQTensor) + mxfp8_recipe = isinstance(unpermuted_act_grad, MXFP8Tensor) + + if fp8: + fp8_dtype = unpermuted_act_grad._fp8_dtype + fake_dtype = unpermuted_act_grad.dtype + if per_tensor_recipe: + fp8_scale = None + scale_hidden_dim = None + fp8_scale_inv = unpermuted_act_grad._scale_inv + unpermuted_act_grad = unpermuted_act_grad._data + elif blockwise_recipe: + fp8_scale = unpermuted_act_grad._rowwise_scale_inv.T.contiguous() + unpermuted_act_grad = unpermuted_act_grad._rowwise_data + scale_hidden_dim = fp8_scale.shape[1] + assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" + elif mxfp8_recipe: + fp8_scale = unpermuted_act_grad._rowwise_scale_inv.contiguous() + unpermuted_act_grad = unpermuted_act_grad._rowwise_data + scale_hidden_dim = fp8_scale.shape[1] + assert num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" + else: + raise ValueError("Unsupported FP8 recipe") + else: + scale_hidden_dim = None + fp8_dtype = None + fp8_scale = None + + act_grad, permuted_scale, _ = triton_permutation.permute_with_mask_map( + unpermuted_act_grad, + row_id_map, + None, + fp8_scale, + pad_offsets, + num_tokens, + num_experts, + num_permuted_tokens, + hidden_size, + scale_hidden_dim, + ) + + if fp8: + if per_tensor_recipe: + act_grad = Float8Tensor( + data=act_grad, + fp8_dtype=fp8_dtype, + fp8_scale_inv=fp8_scale_inv, + shape=act_grad.shape, + dtype=fake_dtype, + ) + elif blockwise_recipe: + act_grad = Float8BlockwiseQTensor( + shape=act_grad.shape, + dtype=fake_dtype, + rowwise_data=act_grad, + rowwise_scale_inv=permuted_scale.T.contiguous(), + columnwise_data=None, + columnwise_scale_inv=None, + fp8_dtype=fp8_dtype, + quantizer=None, + is_2D_scaled=False, + requires_grad=act_grad.requires_grad, + ) + elif mxfp8_recipe: + act_grad = MXFP8Tensor( + shape=act_grad.shape, + dtype=fake_dtype, + fp8_dtype=fp8_dtype, + rowwise_data=act_grad, + rowwise_scale_inv=permuted_scale.contiguous(), + columnwise_data=None, + columnwise_scale_inv=None, + quantizer=None, + requires_grad=act_grad.requires_grad, + with_gemm_swizzled_scales=False, ) - if not ctx.needs_input_grad[3]: - probs_grad = None - return act_grad, None, None, probs_grad, None - - -class _moe_unpermute_mask_map(torch.autograd.Function): - """functional Unpermute with mask router map""" - - @staticmethod - def forward( - ctx, - inp: torch.Tensor, - row_id_map: torch.Tensor, - merging_probs: Optional[torch.Tensor], - restore_shape: Optional[torch.Size], - pad_offsets: Optional[torch.Tensor], - ) -> torch.Tensor: - # pylint: disable=missing-function-docstring - if not inp.numel(): - ctx.merging_probs = merging_probs - return inp - if restore_shape is None: - restore_shape = inp.shape - num_tokens, hidden_size = restore_shape - num_experts = (row_id_map.size(1) - 1) // 2 + return act_grad - with_probs = merging_probs is not None - if with_probs: - assert merging_probs.is_cuda, "TransformerEngine needs CUDA." - # Device check - assert inp.is_cuda, "TransformerEngine needs CUDA." - assert row_id_map.is_cuda, "TransformerEngine needs CUDA." - if pad_offsets is not None: - assert pad_offsets.is_cuda, "TransformerEngine needs CUDA." 
+@moe_unpermute_mask_map_backward_no_probs.register_fake +def _moe_unpermute_mask_map_bwd_no_probs_fake( + unpermuted_act_grad: torch.Tensor, + row_id_map: torch.Tensor, + pad_offsets: Optional[torch.Tensor], + num_tokens: int, + num_experts: int, + num_permuted_tokens: int, + hidden_size: int, +) -> torch.Tensor: + """Fake for backward shape inference without probs.""" + return torch.empty( + (num_permuted_tokens, hidden_size), + dtype=unpermuted_act_grad.dtype, + device=unpermuted_act_grad.device, + ) + +def _moe_unpermute_mask_map_setup_context(ctx, inputs, output): + """Save context for backward pass.""" + inp, row_id_map, merging_probs, num_tokens, num_experts, hidden_size, pad_offsets = inputs + ctx.num_experts = num_experts + ctx.num_tokens = num_tokens + ctx.num_permuted_tokens = inp.size(0) + ctx.hidden_size = hidden_size + ctx.with_probs = merging_probs is not None + if ctx.with_probs: + ctx.save_for_backward(inp, row_id_map, merging_probs, pad_offsets) + ctx.needs_probs_grad = merging_probs.requires_grad + else: + ctx.save_for_backward(row_id_map, pad_offsets) + ctx.needs_probs_grad = False + + +def _moe_unpermute_mask_map_backward_wrapper(ctx, unpermuted_act_grad): + """Backward wrapper calling the appropriate custom backward op.""" + act_grad = None + probs_grad = None + + if ctx.with_probs: + fwd_input, row_id_map, merging_probs, pad_offsets = ctx.saved_tensors assert not isinstance( - inp, QuantizedTensor - ), "The forward of moe_unpermute does not support FP8." - unpermuted_output, _ = triton_permutation.unpermute_with_mask_map( - inp, + unpermuted_act_grad, QuantizedTensor + ), "The backward of moe_unpermute with merging probs does not support FP8." + act_grad, probs_grad = torch.ops.te_moe.unpermute_mask_map_bwd_with_probs( + unpermuted_act_grad, row_id_map, + fwd_input, merging_probs, - None, pad_offsets, - num_tokens, - num_experts, - hidden_size, + ctx.num_tokens, + ctx.num_experts, + ctx.num_permuted_tokens, + ctx.hidden_size, + ) + else: + row_id_map, pad_offsets = ctx.saved_tensors + act_grad = torch.ops.te_moe.unpermute_mask_map_bwd_no_probs( + unpermuted_act_grad, + row_id_map, + pad_offsets, + ctx.num_tokens, + ctx.num_experts, + ctx.num_permuted_tokens, + ctx.hidden_size, ) - if with_probs: - ctx.save_for_backward(inp, row_id_map, merging_probs, pad_offsets) - else: - ctx.save_for_backward(row_id_map, pad_offsets) - ctx.num_experts = num_experts - ctx.num_tokens = num_tokens - ctx.num_permuted_tokens = inp.size(0) - ctx.hidden_size = hidden_size - ctx.with_probs = with_probs - return unpermuted_output - - @staticmethod - def backward(ctx, unpermuted_act_grad): - # pylint: disable=missing-function-docstring - if not unpermuted_act_grad.numel(): - return unpermuted_act_grad, None, ctx.merging_probs, None, None - - act_grad = None + if not ctx.needs_probs_grad: probs_grad = None - if ctx.needs_input_grad[0]: - if ctx.with_probs: - fwd_input, row_id_map, merging_probs, pad_offsets = ctx.saved_tensors - else: - row_id_map, pad_offsets = ctx.saved_tensors - - fp8 = isinstance(unpermuted_act_grad, QuantizedTensor) - per_tensor_recipe = isinstance(unpermuted_act_grad, Float8Tensor) - blockwise_recipe = isinstance(unpermuted_act_grad, Float8BlockwiseQTensor) - mxfp8_recipe = isinstance(unpermuted_act_grad, MXFP8Tensor) - - if fp8: - fp8_dtype = unpermuted_act_grad._fp8_dtype - fake_dtype = unpermuted_act_grad.dtype - # per-tensor scaling - if per_tensor_recipe: - # Kernel does not need scale in per-tensor scaling - fp8_scale = None - scale_hidden_dim = None - 
fp8_scale_inv = unpermuted_act_grad._scale_inv - unpermuted_act_grad = unpermuted_act_grad._data - # blockwise scaling - elif blockwise_recipe: - fp8_scale = unpermuted_act_grad._rowwise_scale_inv.T.contiguous() - unpermuted_act_grad = unpermuted_act_grad._rowwise_data - scale_hidden_dim = fp8_scale.shape[1] - assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" - # mxfp8 scaling - elif mxfp8_recipe: - fp8_scale = unpermuted_act_grad._rowwise_scale_inv.contiguous() - unpermuted_act_grad = unpermuted_act_grad._rowwise_data - scale_hidden_dim = fp8_scale.shape[1] - assert ctx.num_tokens == fp8_scale.shape[0], "scale and input shape mismatch" - else: - raise ValueError("Unsupported FP8 recipe") - else: - scale_hidden_dim = None - fp8_dtype = None - fp8_scale = None - - if ctx.with_probs: - assert ( - not fp8 - ), "The backward of moe_unpermute with merging probs does not support FP8." - act_grad, probs_grad = ( - triton_permutation.unpermute_with_mask_map_bwd_with_merging_probs( - unpermuted_act_grad, - row_id_map, - fwd_input, - merging_probs, - pad_offsets, - ctx.num_tokens, - ctx.num_experts, - ctx.num_permuted_tokens, - ctx.hidden_size, - ) - ) - else: - act_grad, permuted_scale, _ = triton_permutation.permute_with_mask_map( - unpermuted_act_grad, - row_id_map, - None, - fp8_scale, - pad_offsets, - ctx.num_tokens, - ctx.num_experts, - ctx.num_permuted_tokens, - ctx.hidden_size, - scale_hidden_dim, - ) - if fp8: - if per_tensor_recipe: - act_grad = Float8Tensor( - data=act_grad, - fp8_dtype=fp8_dtype, - fp8_scale_inv=fp8_scale_inv, - shape=act_grad.shape, - dtype=fake_dtype, - ) - elif blockwise_recipe: - act_grad = Float8BlockwiseQTensor( - shape=act_grad.shape, - dtype=fake_dtype, - rowwise_data=act_grad, - rowwise_scale_inv=permuted_scale.T.contiguous(), - columnwise_data=None, - columnwise_scale_inv=None, - fp8_dtype=fp8_dtype, - quantizer=None, - is_2D_scaled=False, - requires_grad=act_grad.requires_grad, - ) - elif mxfp8_recipe: - act_grad = MXFP8Tensor( - shape=act_grad.shape, - dtype=fake_dtype, - fp8_dtype=fp8_dtype, - rowwise_data=act_grad, - rowwise_scale_inv=permuted_scale.contiguous(), - columnwise_data=None, - columnwise_scale_inv=None, - quantizer=None, - requires_grad=act_grad.requires_grad, - with_gemm_swizzled_scales=False, - ) - - if not ctx.needs_input_grad[2]: - probs_grad = None - return act_grad, None, probs_grad, None, None + return act_grad, None, probs_grad, None, None, None, None + + +moe_unpermute_mask_map_forward.register_autograd( + _moe_unpermute_mask_map_backward_wrapper, + setup_context=_moe_unpermute_mask_map_setup_context, +) + +# Register all te_moe custom ops as passthrough in QuantizedTensor.__torch_dispatch__ +# so that FP8 tensors are not unwrapped before entering these ops. +_quantized_tensor_passthrough_ops.update( + { + torch.ops.te_moe.permute_mask_map_fwd.default, + torch.ops.te_moe.permute_mask_map_bwd.default, + torch.ops.te_moe.unpermute_mask_map_fwd.default, + torch.ops.te_moe.unpermute_mask_map_bwd_with_probs.default, + torch.ops.te_moe.unpermute_mask_map_bwd_no_probs.default, + } +) def moe_permute( @@ -547,10 +781,21 @@ def moe_permute( Options are: 'mask', 'index'. Refer to `routing_map` for more details. """ + if not inp.numel(): + return inp, torch.tensor([], device=inp.device) + assert inp.is_cuda, "TransformerEngine needs CUDA." + assert routing_map.is_cuda, "TransformerEngine needs CUDA." 
+ assert inp.size(0) == routing_map.size(0), "Permute not possible" + if routing_map.dtype != torch.int32: + warnings.warn( + f"The data type of the input `routing_map` of Permute is {routing_map.dtype}! " + "The recommended type is torch.int32." + ) + routing_map = routing_map.to(torch.int32) if map_type == "index": - return _moe_permute_index_map.apply(inp, routing_map, num_out_tokens, max_token_num) + return torch.ops.te_moe.permute_index_map(inp, routing_map, num_out_tokens, max_token_num) if map_type == "mask": - output, row_id_map, _ = _moe_permute_mask_map.apply( + output, row_id_map, _ = torch.ops.te_moe.permute_mask_map_fwd( inp, routing_map, num_out_tokens, None, None ) return output, row_id_map @@ -584,7 +829,15 @@ def moe_permute_with_probs( The effective output token count, representing the number of tokens not dropped. By default, set to '-1', meaning no tokens are dropped. """ - output, row_id_map, permuted_probs = _moe_permute_mask_map.apply( + if not inp.numel(): + # Keep probs in autograd graph so that probs.grad is an empty tensor + # instead of None after backward (backward compatibility). + return ( + inp + probs.sum() * 0, + probs.sum(dim=1), + torch.tensor([], device=inp.device), + ) + output, row_id_map, permuted_probs = torch.ops.te_moe.permute_mask_map_fwd( inp, routing_map, num_out_tokens, probs, None ) return output, permuted_probs, row_id_map @@ -640,7 +893,7 @@ def moe_permute_and_pad_with_probs( [torch.zeros(1, dtype=cum_pad.dtype, device=inp.device), cum_pad[:-1]] ) - output, row_id_map, permuted_probs = _moe_permute_mask_map.apply( + output, row_id_map, permuted_probs = torch.ops.te_moe.permute_mask_map_fwd( inp, routing_map, target_tokens_per_expert.sum().item(), probs, pad_offsets ) return output, permuted_probs, row_id_map, pad_offsets, target_tokens_per_expert @@ -690,113 +943,249 @@ def moe_unpermute( warnings.warn("probs kwarg is deprecated. Use merging_probs kwarg instead.") merging_probs = probs if map_type == "index": - return _moe_unpermute_index_map.apply(inp, row_id_map, merging_probs) - if map_type == "mask": - return _moe_unpermute_mask_map.apply( - inp, row_id_map, merging_probs, restore_shape, pad_offsets - ) - raise ValueError("map_type should be one of 'mask' or 'index'") + # Empty input check + if not inp.numel(): + return inp + # Normalize probs + if merging_probs is not None: + assert merging_probs.is_cuda, "TransformerEngine needs CUDA." + if merging_probs.dtype != torch.float32: + warnings.warn( + f"The data type of the input `probs` of Unpermute is {merging_probs.dtype}! " + "The recommended type is torch.float32." + ) + merging_probs = merging_probs.to(torch.float32) + num_tokens = merging_probs.size(0) + topK = merging_probs.size(1) + else: + num_tokens = row_id_map.size(0) + topK = 1 + merging_probs = torch.empty(0, device=inp.device) -class _moe_chunk_sort(torch.autograd.Function): - """functional MoE chunk permute""" + # Device check + assert inp.is_cuda, "TransformerEngine needs CUDA." + assert row_id_map.is_cuda, "TransformerEngine needs CUDA." + if row_id_map.dtype != torch.int32: + warnings.warn( + f"The data type of the input `row_id_map` of Unpermute is {row_id_map.dtype}! " + "The recommended type is torch.int32." 
+        )
+        row_id_map = row_id_map.to(torch.int32)
 
-    @staticmethod
-    def forward(
-        ctx,
-        inp: torch.Tensor,
-        split_sizes: torch.Tensor,
-        sorted_idxs: torch.Tensor,
-        probs: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # pylint: disable=missing-function-docstring
+        return torch.ops.te_moe.unpermute_index_map_fwd(
+            inp, row_id_map, merging_probs, num_tokens, topK
+        )
+    if map_type == "mask":
         if not inp.numel():
-            return inp, probs
+            # Keep merging_probs in autograd graph so that probs.grad is an empty
+            # tensor instead of None after backward (backward compatibility).
+            if merging_probs is not None:
+                return inp + merging_probs.sum() * 0
+            return inp
+
+        if restore_shape is None:
+            restore_shape = inp.shape
+        num_tokens, hidden_size = restore_shape
+        num_experts = (row_id_map.size(1) - 1) // 2
+        if merging_probs is not None:
+            assert merging_probs.is_cuda, "TransformerEngine needs CUDA."
         assert inp.is_cuda, "TransformerEngine needs CUDA."
-        assert split_sizes.is_cuda, "TransformerEngine needs CUDA."
-        assert sorted_idxs.is_cuda, "TransformerEngine needs CUDA."
-        if probs is not None:
-            assert probs.is_cuda, "TransformerEngine needs CUDA."
-
-        num_tokens, hidden_size = inp.shape
-        num_splits = split_sizes.size(0)
-        assert num_splits == sorted_idxs.size(0)
-
-        fp8 = isinstance(inp, Float8Tensor)
-        if fp8:
-            fp8_dtype = inp._fp8_dtype
-            fp8_scale_inv = inp._scale_inv
-            fake_dtype = inp.dtype
-            inp = inp._data
+        assert row_id_map.is_cuda, "TransformerEngine needs CUDA."
+        if pad_offsets is not None:
+            assert pad_offsets.is_cuda, "TransformerEngine needs CUDA."
 
-        row_id_map = triton_permutation.make_chunk_sort_map(
-            split_sizes,
-            sorted_idxs,
-            num_tokens,
-            num_splits,
-        )
-        output, permuted_probs = triton_permutation.sort_chunks_by_map(
+        return torch.ops.te_moe.unpermute_mask_map_fwd(
             inp,
             row_id_map,
-            probs,
+            merging_probs,
             num_tokens,
+            num_experts,
             hidden_size,
-            is_forward=True,
+            pad_offsets,
+        )
+    raise ValueError("map_type should be one of 'mask' or 'index'")
+
+
+# ===================== _moe_chunk_sort custom ops =====================
+
+
+@torch.library.custom_op("te_moe::chunk_sort_fwd", mutates_args=[])
+def moe_chunk_sort_forward(
+    inp: torch.Tensor,
+    split_sizes: torch.Tensor,
+    sorted_idxs: torch.Tensor,
+    probs: Optional[torch.Tensor],
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Forward pass for MoE chunk sort.
+    Returns (output, permuted_probs, row_id_map)."""
+    num_tokens, hidden_size = inp.shape
+    num_splits = split_sizes.size(0)
+
+    fp8 = isinstance(inp, Float8Tensor)
+    if fp8:
+        fp8_dtype = inp._fp8_dtype
+        fp8_scale_inv = inp._scale_inv
+        fake_dtype = inp.dtype
+        inp = inp._data
+
+    row_id_map = triton_permutation.make_chunk_sort_map(
+        split_sizes,
+        sorted_idxs,
+        num_tokens,
+        num_splits,
+    )
+    output, permuted_probs = triton_permutation.sort_chunks_by_map(
+        inp,
+        row_id_map,
+        probs,
+        num_tokens,
+        hidden_size,
+        is_forward=True,
+    )
+    if fp8:
+        output = Float8Tensor(
+            data=output,
+            fp8_dtype=fp8_dtype,
+            fp8_scale_inv=fp8_scale_inv,
+            shape=output.shape,
+            dtype=fake_dtype,
         )
-        if fp8:
-            output = Float8Tensor(
-                data=output,
-                fp8_dtype=fp8_dtype,
-                fp8_scale_inv=fp8_scale_inv,
-                shape=output.shape,
-                dtype=fake_dtype,
-            )
-        ctx.save_for_backward(row_id_map)
-        ctx.num_tokens = num_tokens
-        ctx.hidden_size = hidden_size
-        return output, permuted_probs
-
-    @staticmethod
-    def backward(
-        ctx,
-        permuted_act_grad: torch.Tensor,
-        permuted_probs_grad: torch.Tensor,
-    ) -> Tuple[torch.Tensor, ...]:
-        # pylint: disable=missing-function-docstring
-        if not permuted_act_grad.numel():
-            return permuted_act_grad, None, None, permuted_probs_grad
-
-        act_grad = None
+    if permuted_probs is None:
+        permuted_probs = torch.empty(0, device=output.device)
+
+    return output, permuted_probs, row_id_map
+
+
+@moe_chunk_sort_forward.register_fake
+def _moe_chunk_sort_forward_fake(
+    inp: torch.Tensor,
+    split_sizes: torch.Tensor,
+    sorted_idxs: torch.Tensor,
+    probs: Optional[torch.Tensor],
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Fake for shape inference."""
+    num_tokens = inp.shape[0]
+    hidden_size = inp.shape[1]
+    fake_output = torch.empty((num_tokens, hidden_size), dtype=inp.dtype, device=inp.device)
+    if probs is not None:
+        fake_probs = torch.empty((num_tokens,), dtype=probs.dtype, device=inp.device)
+    else:
+        fake_probs = torch.empty(0, device=inp.device)
+    # row_id_map: 1D, size num_tokens
+    fake_row_id_map = torch.empty((num_tokens,), dtype=torch.int32, device=inp.device)
+    return fake_output, fake_probs, fake_row_id_map
+
+
+@torch.library.custom_op("te_moe::chunk_sort_bwd", mutates_args=[])
+def moe_chunk_sort_backward(
+    permuted_act_grad: torch.Tensor,
+    permuted_probs_grad: Optional[torch.Tensor],
+    row_id_map: torch.Tensor,
+    num_tokens: int,
+    hidden_size: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Backward pass for MoE chunk sort."""
+    fp8 = isinstance(permuted_act_grad, Float8Tensor)
+    if fp8:
+        fp8_dtype = permuted_act_grad._fp8_dtype
+        fp8_scale_inv = permuted_act_grad._scale_inv
+        fake_dtype = permuted_act_grad.dtype
+        permuted_act_grad = permuted_act_grad._data
+
+    act_grad, probs_grad = triton_permutation.sort_chunks_by_map(
+        permuted_act_grad,
+        row_id_map,
+        permuted_probs_grad,
+        num_tokens,
+        hidden_size,
+        is_forward=False,
+    )
+
+    if fp8:
+        act_grad = Float8Tensor(
+            data=act_grad,
+            fp8_dtype=fp8_dtype,
+            fp8_scale_inv=fp8_scale_inv,
+            shape=act_grad.shape,
+            dtype=fake_dtype,
+        )
+
+    if probs_grad is None:
+        probs_grad = torch.empty(0, device=act_grad.device)
+
+    return act_grad, probs_grad
+
+
+@moe_chunk_sort_backward.register_fake
+def _moe_chunk_sort_backward_fake(
+    permuted_act_grad: torch.Tensor,
+    permuted_probs_grad: Optional[torch.Tensor],
+    row_id_map: torch.Tensor,
+    num_tokens: int,
+    hidden_size: int,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """Fake for backward shape inference."""
+    fake_act_grad = torch.empty(
+        (num_tokens, hidden_size),
+        dtype=permuted_act_grad.dtype,
+        device=permuted_act_grad.device,
+    )
+    if permuted_probs_grad is not None:
+        fake_probs_grad = torch.empty(
+            (num_tokens,),
+            dtype=permuted_probs_grad.dtype,
+            device=permuted_act_grad.device,
+        )
+    else:
+        fake_probs_grad = torch.empty(0, device=permuted_act_grad.device)
+    return fake_act_grad, fake_probs_grad
+
+
+def _moe_chunk_sort_setup_context(ctx, inputs, output):
+    """Save context for backward pass."""
+    inp, split_sizes, sorted_idxs, probs = inputs
+    output_tensor, permuted_probs, row_id_map = output
+
+    ctx.save_for_backward(row_id_map)
+    ctx.num_tokens = inp.size(0)
+    ctx.hidden_size = inp.size(1)
+    ctx.needs_probs_grad = probs is not None and probs.requires_grad
+
+
+def _moe_chunk_sort_backward_wrapper(ctx, permuted_act_grad, permuted_probs_grad, _row_id_map_grad):
+    """Backward wrapper calling the custom backward op."""
+    (row_id_map,) = ctx.saved_tensors
+
+    probs_grad_input = permuted_probs_grad if permuted_probs_grad.numel() > 0 else None
+
+    act_grad, probs_grad = torch.ops.te_moe.chunk_sort_bwd(
+        permuted_act_grad,
+        probs_grad_input,
+        row_id_map,
+        ctx.num_tokens,
+        ctx.hidden_size,
+    )
+
+    if not ctx.needs_probs_grad or probs_grad.numel() == 0:
         probs_grad = None
-        if ctx.needs_input_grad[0]:
-            (row_id_map,) = ctx.saved_tensors
-            fp8 = isinstance(permuted_act_grad, Float8Tensor)
-            if fp8:
-                fp8_dtype = permuted_act_grad._fp8_dtype
-                fp8_scale_inv = permuted_act_grad._scale_inv
-                fake_dtype = permuted_act_grad.dtype
-                permuted_act_grad = permuted_act_grad._data
-            act_grad, probs_grad = triton_permutation.sort_chunks_by_map(
-                permuted_act_grad,
-                row_id_map,
-                permuted_probs_grad,
-                ctx.num_tokens,
-                ctx.hidden_size,
-                is_forward=False,
-            )
-            if fp8:
-                act_grad = Float8Tensor(
-                    data=act_grad,
-                    fp8_dtype=fp8_dtype,
-                    fp8_scale_inv=fp8_scale_inv,
-                    shape=act_grad.shape,
-                    dtype=fake_dtype,
-                )
-            if not ctx.needs_input_grad[3]:
-                probs_grad = None
-        return act_grad, None, None, probs_grad
+
+    return act_grad, None, None, probs_grad
+
+
+moe_chunk_sort_forward.register_autograd(
+    _moe_chunk_sort_backward_wrapper,
+    setup_context=_moe_chunk_sort_setup_context,
+)
+
+# Register chunk sort ops as passthrough in QuantizedTensor.__torch_dispatch__
+_quantized_tensor_passthrough_ops.update(
+    {
+        torch.ops.te_moe.chunk_sort_fwd.default,
+        torch.ops.te_moe.chunk_sort_bwd.default,
+    }
+)
 
 
 def moe_sort_chunks_by_index(
@@ -818,7 +1207,9 @@ def moe_sort_chunks_by_index(
     sorted_indices : torch.Tensor
        Chunk indices used to permute the chunks.
     """
-    output, _ = _moe_chunk_sort.apply(inp, split_sizes, sorted_index, None)
+    if not inp.numel():
+        return inp
+    output, _, _ = torch.ops.te_moe.chunk_sort_fwd(inp, split_sizes, sorted_index, None)
     return output
 
 
@@ -846,5 +1237,9 @@ def moe_sort_chunks_by_index_with_probs(
     sorted_indices : torch.Tensor
        Chunk indices used to permute the chunks.
     """
-    output, permuted_probs = _moe_chunk_sort.apply(inp, split_sizes, sorted_index, probs)
+    if not inp.numel():
+        return inp, probs
+    output, permuted_probs, _ = torch.ops.te_moe.chunk_sort_fwd(
+        inp, split_sizes, sorted_index, probs
+    )
     return output, permuted_probs
diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py
index d78677bc83..1bac4b2b53 100644
--- a/transformer_engine/pytorch/quantized_tensor.py
+++ b/transformer_engine/pytorch/quantized_tensor.py
@@ -21,6 +21,12 @@
 )
 
 
+# Custom ops that should pass through __torch_dispatch__ without unwrapping
+# QuantizedTensor subclasses (e.g. Float8Tensor). Register ops here that
+# handle quantized tensors internally.
+_quantized_tensor_passthrough_ops: set = set()
+
+
 class QuantizedTensorStorage:
     r"""Base class for all TensorStorage classes.
 
@@ -516,6 +522,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
                 return func(t)
             return False  # Or error out?
 
+        # Pass through registered custom ops without unwrapping
+        if func in _quantized_tensor_passthrough_ops:
+            if kwargs is None:
+                kwargs = {}
+            return super().__torch_dispatch__(func, types, args, kwargs)
+
         def maybe_unwrap(arg):
             if isinstance(arg, QuantizedTensor):
                 return arg.dequantize(dtype=arg.dtype)
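Note (reviewer illustration, not part of the patch): the chunk-sort rewrite above follows the standard torch.library custom-op recipe — an eager kernel registered with torch.library.custom_op, a register_fake that propagates shapes and dtypes while tracing, and register_autograd with a setup_context/backward pair — which is what lets the tests wrap these paths in torch.compile(fullgraph=True). Below is a minimal, self-contained sketch of that recipe on a toy op; the op name toy::reverse_rows and all helper names are invented here for illustration and mirror only the structure of te_moe::chunk_sort_fwd, not TransformerEngine code.

import torch
from torch.library import custom_op


@custom_op("toy::reverse_rows", mutates_args=())
def reverse_rows(inp: torch.Tensor) -> torch.Tensor:
    """Eager kernel: a toy stand-in for the permutation kernel."""
    return inp.flip(0).contiguous()


@reverse_rows.register_fake
def _reverse_rows_fake(inp: torch.Tensor) -> torch.Tensor:
    """Shape/dtype-only result used while torch.compile traces the graph."""
    return torch.empty_like(inp)


def _setup_context(ctx, inputs, output):
    """Nothing to save: reversing rows is its own inverse."""


def _backward(ctx, grad_out):
    """Route the gradient back through the inverse permutation."""
    return grad_out.flip(0).contiguous()


reverse_rows.register_autograd(_backward, setup_context=_setup_context)


@torch.compile(fullgraph=True)
def fn(x):
    return reverse_rows(x).pow(2).sum()


x = torch.randn(4, 8, requires_grad=True)
fn(x).backward()
print(torch.allclose(x.grad, 2 * x.detach()))  # True: d/dx sum(reverse(x)**2) == 2 * x

Returning a concrete tensor for every output (for example the empty tensors the patch substitutes when probs is None), as the new chunk-sort ops do, keeps the op schema fixed, which the fake-tensor tracing path requires.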