tests/pytorch/test_permutation.py (62 changes: 56 additions & 6 deletions)
@@ -227,6 +227,7 @@ def _test_permutation_index_map(
num_out_tokens,
with_probs,
BENCHMARK=False,
use_torch_compile=False,
):
if not with_probs and topK > 1:
pytest.skip("Only permutations with topK=1 and without probabilities are supported.")
@@ -298,9 +299,28 @@ def _test_permutation_index_map(
te_permute_fwd_input.requires_grad_(True)
te_permute_bwd_input = pytorch_permute_bwd_input.detach()

te_permute_output, row_id_map = te_permute(
te_permute_fwd_input, indices, num_out_tokens, map_type="index"
)
if use_torch_compile:
# Reset dynamo to avoid recompile limit across parametrized tests
torch._dynamo.reset()
# Disable donated buffers to allow retain_graph=True
import torch._functorch.config as functorch_config

old_donated_buffer = functorch_config.donated_buffer
functorch_config.donated_buffer = False

# Create a wrapper function for torch.compile
def permute_wrapper(inp, idx, num_out, max_token):
return te_permute(inp, idx, num_out, max_token, map_type="index")

# Compile with fullgraph=True so any graph break fails loudly
compiled_permute = torch.compile(permute_wrapper, fullgraph=True)
te_permute_output, row_id_map = compiled_permute(
te_permute_fwd_input, indices, num_out_tokens, -1
)
else:
te_permute_output, row_id_map = te_permute(
te_permute_fwd_input, indices, num_out_tokens, map_type="index"
)
te_permute_output.backward(te_permute_bwd_input, retain_graph=True)

te_probs = None
@@ -311,11 +331,23 @@ def _test_permutation_index_map(
te_unpermute_fwd_input.requires_grad_(True)
te_unpermute_bwd_input = pytorch_unpermute_bwd_input.detach()

te_unpermute_output = te_unpermute(
te_unpermute_fwd_input, row_id_map, te_probs, map_type="index"
)
if use_torch_compile:
# Create a wrapper function for torch.compile
def unpermute_wrapper(inp, row_map, probs_val):
return te_unpermute(inp, row_map, probs_val, map_type="index")

# Compile with fullgraph=True so any graph break fails loudly
compiled_unpermute = torch.compile(unpermute_wrapper, fullgraph=True)
te_unpermute_output = compiled_unpermute(te_unpermute_fwd_input, row_id_map, te_probs)
else:
te_unpermute_output = te_unpermute(
te_unpermute_fwd_input, row_id_map, te_probs, map_type="index"
)
te_unpermute_output.backward(te_unpermute_bwd_input, retain_graph=True)

if use_torch_compile:
functorch_config.donated_buffer = old_donated_buffer

###################################################################################################################################
#
# Results Check
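A minimal standalone sketch of the compile-and-backward pattern exercised above. toy_permute is a hypothetical stand-in for te_permute, and a try/finally restores the donated_buffer flag even if an assertion fails mid-test (assumes a PyTorch recent enough to expose torch._functorch.config.donated_buffer):

import torch
import torch._dynamo
import torch._functorch.config as functorch_config

def toy_permute(inp, idx):
    # Hypothetical stand-in for te_permute: gather rows by index.
    return inp[idx]

inp = torch.randn(8, 4, requires_grad=True)
idx = torch.randperm(8)

torch._dynamo.reset()  # start from a clean compile cache
old_donated_buffer = functorch_config.donated_buffer
# Donated (reused) gradient buffers conflict with retain_graph=True.
functorch_config.donated_buffer = False
try:
    compiled = torch.compile(toy_permute, fullgraph=True)
    out = compiled(inp, idx)
    # retain_graph=True keeps the autograd graph alive so backward can run twice.
    out.backward(torch.ones_like(out), retain_graph=True)
    out.backward(torch.ones_like(out))
finally:
    functorch_config.donated_buffer = old_donated_buffer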
@@ -1647,14 +1679,22 @@ def perf_test_cuda_kernel(cuda_kernel_fn):
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("topK", [2, 5])
@pytest.mark.parametrize("num_out_tokens", [None, 2039])
@pytest.mark.parametrize("use_torch_compile", [False, True])
def test_permutation_index_map(
te_dtype,
num_tokens,
num_expert,
hidden_size,
topK,
num_out_tokens,
use_torch_compile,
):
if use_torch_compile and torch.__version__ < "2":
pytest.skip("torch.compile not available")
with_probs = True
BENCHMARK = False

@@ -1667,6 +1707,7 @@ def test_permutation_index_map(
num_out_tokens=num_out_tokens,
with_probs=with_probs,
BENCHMARK=BENCHMARK,
use_torch_compile=use_torch_compile,
)
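The in-function guard above compares version strings lexically, which happens to work for "1.x" vs "2" but breaks in general ("10.0" < "2"). A more robust gate, as a sketch (assumes the packaging distribution, which pytest environments typically have):

from packaging import version
import torch

def torch_compile_available():
    # Parse the version instead of comparing strings lexically.
    return version.parse(torch.__version__).release >= (2,)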


@@ -1875,12 +1916,20 @@ def test_permutation_mask_map_fp8(
@pytest.mark.parametrize("num_tokens", [4096])
@pytest.mark.parametrize("num_expert", [7, 16])
@pytest.mark.parametrize("hidden_size", [4096])
@pytest.mark.parametrize("use_torch_compile", [False, True])
def test_permutation_index_map_topk1_no_probs(
te_dtype,
num_tokens,
num_expert,
hidden_size,
use_torch_compile,
):
if use_torch_compile and torch.__version__ < "2":
pytest.skip("torch.compile not available")
topK = 1
num_out_tokens = None
with_probs = False
@@ -1895,6 +1944,7 @@ def test_permutation_index_map_topk1_no_probs(
num_out_tokens=num_out_tokens,
with_probs=with_probs,
BENCHMARK=BENCHMARK,
use_torch_compile=use_torch_compile,
)
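An alternative to calling torch._dynamo.reset() between parametrized cases is to raise the recompile cache limit up front, as a sketch (the default limit is small, and every dtype/shape/flag combination can trigger a fresh compile):

import torch._dynamo

# Avoid hitting the recompile limit, which would make dynamo
# fall back to eager for the remaining parametrizations.
torch._dynamo.config.cache_size_limit = 64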

