Merged
72 commits
a177cb4
add ParallelType::Stream lowering pass in host Ir for single device f…
samnordmann Mar 26, 2025
e886941
improve comments
samnordmann Mar 26, 2025
b6c54f2
fix rebase
samnordmann Apr 16, 2025
55f510f
Merge branch 'host_irs/LoadStore_Reduction_binaryOp_support' into hos…
samnordmann Apr 16, 2025
32a8d55
temporarily disable stream pass also in the python test
samnordmann Apr 16, 2025
afbd020
lint
samnordmann Apr 16, 2025
165bd1b
move stream_parallel_type to host_ir/pass folder
samnordmann Apr 16, 2025
59ba13c
Print all ID expressions in tv->printTransforms (#4258)
jacobhinkle Apr 16, 2025
85a9463
InsertReshardingsPass decomposes matmul/linear+ReduceScatter. (#4239)
wujingyue Apr 16, 2025
34fa83b
Create kir::Continue for persistent grid short-circuit (#4260)
rdspring1 Apr 16, 2025
5ecf7fe
Remove several uses of NVFUSER_DISTRIBUTED (#4255)
wujingyue Apr 17, 2025
1484997
warp specializied tma persistent kernel, step-2, use TMA load (#4240)
liqiangxl Apr 17, 2025
9b9cd8f
Fix scheduling of split-K with smem_epilogue on Hopper (#4257)
jacobhinkle Apr 17, 2025
1bc13d8
Add NVFUSER_DUMP=sass_to_file option (#4263)
jacobhinkle Apr 17, 2025
c477a3f
disable TmaWarpSpecializedTes, it needs predicate (#4267)
liqiangxl Apr 17, 2025
f3b22ab
Create separate AsyncGroup helpers for fence, commit, and wait operat…
rdspring1 Apr 18, 2025
c1d8423
Rename LoadWarp to AsyncWarp (#4270)
rdspring1 Apr 18, 2025
6851163
Remove stale exprs (#4268)
naoyam Apr 18, 2025
5494b0a
Unskip the DeepSeek test (#4273)
wujingyue Apr 18, 2025
1181eac
minor improvements and cleanup
samnordmann Apr 18, 2025
cad9bce
further refactor of stream pass
samnordmann Apr 18, 2025
7ae7c52
improve comments clarity
samnordmann Apr 18, 2025
6dd673f
more comments
samnordmann Apr 18, 2025
bc8c2cb
Merge branch 'host_irs/LoadStore_Reduction_binaryOp_support' into hos…
samnordmann Apr 18, 2025
ac7e09a
Adding IndexPutAccumulateOp (#4063)
jjsjann123 Apr 18, 2025
f24dc13
Minor fix on inline_ptx.cpp (#4278)
zasdfgbnm Apr 18, 2025
7bc8c17
Rename `ldstMBarrierMap` -> `mbarrierMap` (#4277)
zasdfgbnm Apr 18, 2025
ed68736
`shardAllLike` accepts a list of parallel types (#4254)
Priya2698 Apr 18, 2025
c969903
Tensor-parallelize the DeepSeek V3 transformer layer (#4062)
wujingyue Apr 19, 2025
bb5b38c
Disable two flaky tests to keep CI green (#4283)
wujingyue Apr 19, 2025
39aec16
Use `tcgen05` as namespace for TMem ld/st (#4279)
zasdfgbnm Apr 19, 2025
857d1df
Use mbarrier to sync Blackwell MMA (#4276)
zasdfgbnm Apr 20, 2025
5dac8bd
Add segmentation helper functions for edge processing (#4222)
csarofeen Apr 21, 2025
da72cae
Add separate files for mutil-wave and tma approaches (#4265)
liqiangxl Apr 21, 2025
5368ed0
Clean up multi-GPU python test fixtures (#4284)
wujingyue Apr 21, 2025
fb9b956
Split insertion_info into Pipeline and WarpSpecialized parts (#4275)
rdspring1 Apr 21, 2025
0b2f5a8
Create separate CircularBufferInserter for WarpSpecialized and Pipeli…
rdspring1 Apr 22, 2025
e697ec9
Add mutex guard to protect data race on options. (#4287)
jjsjann123 Apr 22, 2025
5f9cfb0
fix register spills in thread local outer reduction (#4184)
liqiangxl Apr 23, 2025
24a5cc9
Prefer static local to static global (#4289)
wujingyue Apr 23, 2025
096b681
Move `MarkAliasAnalysisPreparePass` before `propagateShardingsPass` (…
Priya2698 Apr 23, 2025
ab8846a
Replace an ad-hoc toposort with stablyOrderedExprs (#4285)
wujingyue Apr 23, 2025
d8b8cf4
check ID coverage for reference_tv in reduction scheduler (#4223)
jjsjann123 Apr 23, 2025
13a879c
Fix bug in stablyOrderedExprs (#4292)
wujingyue Apr 23, 2025
515e65e
Deallocate HostIr Op and Test (#4286)
nsarka Apr 23, 2025
eef49fc
renaming benchmark (#4293)
jjsjann123 Apr 23, 2025
ce3d607
Issue 4063 normalization scheduler (#4281)
jjsjann123 Apr 23, 2025
db90ef0
add HirAliasSelect
samnordmann Apr 23, 2025
e32653a
replace SelectOp by HirAliasSelect in stream lowering
samnordmann Apr 23, 2025
a50b53c
add cache for tensor slicing
samnordmann Apr 23, 2025
df447be
indexAccumulate python api (#4066)
jjsjann123 Apr 23, 2025
d01c5a2
separate out tensor allocation logic
samnordmann Apr 23, 2025
85f9894
Revert "Deallocate HostIr Op and Test" (#4303)
wujingyue Apr 23, 2025
25b7695
minor cleanup
samnordmann Apr 23, 2025
a958bfc
Forward full op (#4269)
naoyam Apr 24, 2025
c9d2cc9
Update propagateSharding preseg pass for DID loop split (#3838)
Priya2698 Apr 24, 2025
7477e4b
Extract benchmarking timers into a separate class (#4291)
Priya2698 Apr 24, 2025
b2a76e9
add comment
samnordmann Apr 24, 2025
20204fc
Merge branch 'host_irs/LoadStore_Reduction_binaryOp_support' into hos…
samnordmann Apr 24, 2025
e798e06
Refactor python build (#4193)
rdspring1 Apr 24, 2025
95c9bde
Change build directory for clang-tidy in lintrunner (#4309)
rdspring1 Apr 24, 2025
fadfde5
Switch axis we use to compute swizzled_tiles (#4311)
jacobhinkle Apr 25, 2025
87be6c3
Simplify some tests since sharding propagation is in place (#4304)
wujingyue Apr 25, 2025
07effe8
More precise WAR for resize vectorization (#4305)
naoyam Apr 25, 2025
3fe1c32
[Cuda Ipc] Add barrier at the end of `IpcHandleCache::exchangeHandles…
samnordmann Apr 27, 2025
cf5c6d2
[Host ir] support for set reduce and binary op (#4146)
samnordmann Apr 27, 2025
7f7caf5
change namespace of the optimization pass to hir
samnordmann Apr 27, 2025
bfc7ba8
add HirAliasSelect (#4301)
samnordmann Apr 27, 2025
b6213f3
Merge branch 'main' of github.com:NVIDIA/Fuser into host_irs/stream_l…
samnordmann Apr 27, 2025
7777fe0
lint
samnordmann Apr 27, 2025
e517bc3
fix merge
samnordmann Apr 27, 2025
35ff4da
empty commit to trigger the CI
samnordmann Apr 28, 2025
1 change: 1 addition & 0 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
@@ -31,6 +31,7 @@ jobs:
tools/pip-install-things.sh &
source tools/setup-env.sh
wait
cd python
python setup.py build --cpp=23

dynamic-type-meson:
6 changes: 6 additions & 0 deletions .github/workflows/lint.yml
@@ -51,13 +51,19 @@ jobs:

wait

# Go to python folder to build cmake files
cd python

# Run cmake build
python setup.py --cmake-only

# Generate csrc/serde/fusion_cache_generated.h
# NOTE: this might cause a compile of flatbuffers if it is missing
ninja -C build build_flatbuffer_config

# Return to root to run clang-tidy
cd ..

# Run lintrunner on all csrc files exclude benchmark and test folders
this_commit=$(git rev-parse HEAD)
git fetch origin main
12 changes: 8 additions & 4 deletions .gitignore
@@ -4,20 +4,24 @@ bin
# cmake build directory
build
.lintbin

# pip wheel directory
dist

nvfuser/version.py
nvfuser/include
nvfuser/lib
nvfuser/share
nvfuser/cmake

python/build
python/nvfuser/version.py
python/nvfuser/include
python/nvfuser/lib
python/nvfuser/share
python/nvfuser/cmake

.hypothesis
*.egg-info/
**/__pycache__
*/*.so
python/nvfuser/*.so

# Editor temporaries
*.swa
4 changes: 2 additions & 2 deletions .lintrunner.toml
@@ -16,7 +16,7 @@ init_command = [
'python3',
'tools/linter/adapters/pip_init.py',
'--dry-run={{DRYRUN}}',
'flake8==6.0.0',
'flake8==6.1.0',
]


@@ -185,7 +185,7 @@ command = [
'python3',
'tools/linter/adapters/clangtidy_linter.py',
'--binary=~/.local/bin/clang-tidy',
'--build_dir=./build',
'--build_dir=./python/build',
'--',
'@{{PATHSFILE}}'
]
49 changes: 29 additions & 20 deletions CMakeLists.txt
@@ -10,6 +10,7 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

set(NVFUSER_ROOT ${PROJECT_SOURCE_DIR})
set(NVFUSER_SRCS_DIR "${NVFUSER_ROOT}/csrc")
set(NVFUSER_PYTHON_DIR "${NVFUSER_ROOT}/python")
set(NVFUSER_THIRD_PARTY_DIR "${NVFUSER_ROOT}/third_party")

option(NVFUSER_STANDALONE_BUILD_WITH_UCC "" OFF)
@@ -212,6 +213,7 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/preseg_passes/remove_empty.cpp
${NVFUSER_SRCS_DIR}/preseg_passes/reorder_sharded_axis.cpp
${NVFUSER_SRCS_DIR}/preseg_passes/segment_inplace_update.cpp
${NVFUSER_SRCS_DIR}/host_ir/pass/stream_parallel_type.cpp
${NVFUSER_SRCS_DIR}/preseg_passes/translate_no_reduction_matmul_to_mul_squeeze.cpp
${NVFUSER_SRCS_DIR}/preseg_passes/translate_repeat_to_expand.cpp
${NVFUSER_SRCS_DIR}/rng.cpp
@@ -239,6 +241,9 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/scheduler/communication.cpp
${NVFUSER_SRCS_DIR}/scheduler/normalization_inner.cpp
${NVFUSER_SRCS_DIR}/scheduler/normalization_inner_outer.cpp
${NVFUSER_SRCS_DIR}/scheduler/normalization_inner_outer_utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/normalization_inner_outer_tma_ws.cpp
${NVFUSER_SRCS_DIR}/scheduler/normalization_inner_outer_multi_wave.cpp
${NVFUSER_SRCS_DIR}/scheduler/normalization_outer.cpp
${NVFUSER_SRCS_DIR}/scheduler/normalization_utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/pointwise.cpp
@@ -289,13 +294,13 @@ endif()

if(BUILD_PYTHON)
list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/python_frontend/distributed_tensor.cpp
${NVFUSER_SRCS_DIR}/python_frontend/fusion_cache.cpp
${NVFUSER_SRCS_DIR}/python_frontend/fusion_definition.cpp
${NVFUSER_SRCS_DIR}/python_frontend/fusion_state.cpp
${NVFUSER_SRCS_DIR}/python_frontend/segmentation.cpp
${NVFUSER_SRCS_DIR}/python_frontend/translation.cpp
${NVFUSER_SRCS_DIR}/python_frontend/translation_utils.cpp
${NVFUSER_PYTHON_DIR}/python_frontend/distributed_tensor.cpp
${NVFUSER_PYTHON_DIR}/python_frontend/fusion_cache.cpp
${NVFUSER_PYTHON_DIR}/python_frontend/fusion_definition.cpp
${NVFUSER_PYTHON_DIR}/python_frontend/fusion_state.cpp
${NVFUSER_PYTHON_DIR}/python_frontend/segmentation.cpp
${NVFUSER_PYTHON_DIR}/python_frontend/translation.cpp
${NVFUSER_PYTHON_DIR}/python_frontend/translation_utils.cpp
${NVFUSER_SRCS_DIR}/serde/fusion_record.cpp
)
endif()
@@ -331,6 +336,7 @@ if(NOT MSVC)
endif()

target_compile_definitions(codegen_internal PRIVATE "-DTORCH_CUDA_BUILD_MAIN_LIB")
target_include_directories(codegen_internal PUBLIC ${NVFUSER_PYTHON_DIR})
target_include_directories(codegen_internal SYSTEM PUBLIC
${CMAKE_SOURCE_DIR}/third_party/flatbuffers/include
PRIVATE
@@ -457,31 +463,32 @@ if(BUILD_PYTHON)
# nvfuser python API sources
set(NVFUSER_PYTHON_SRCS)
list(APPEND NVFUSER_PYTHON_SRCS
${NVFUSER_SRCS_DIR}/python_frontend/multidevice_bindings.cpp
${NVFUSER_SRCS_DIR}/python_frontend/python_bindings.cpp
${NVFUSER_SRCS_DIR}/python_frontend/python_bindings_extension.cpp
${NVFUSER_SRCS_DIR}/python_frontend/schedule_bindings.cpp
${NVFUSER_PYTHON_DIR}/python_frontend/multidevice_bindings.cpp
${NVFUSER_PYTHON_DIR}/python_frontend/python_bindings.cpp
${NVFUSER_PYTHON_DIR}/python_frontend/python_bindings_extension.cpp
${NVFUSER_PYTHON_DIR}/python_frontend/schedule_bindings.cpp
)

add_library(nvf_py_internal OBJECT ${NVFUSER_PYTHON_SRCS})
target_include_directories(nvf_py_internal PUBLIC ${NVFUSER_PYTHON_DIR})
target_include_directories(nvf_py_internal SYSTEM INTERFACE
${CMAKE_SOURCE_DIR}/third_party/flatbuffers/include
)

# setup python API version
add_custom_command(
OUTPUT ${NVFUSER_ROOT}/nvfuser/version.py
OUTPUT ${NVFUSER_PYTHON_DIR}/nvfuser/version.py
COMMAND
"${PYTHON_EXECUTABLE}" -c \"from pathlib import Path\; Path('${NVFUSER_ROOT}/tools/gen_nvfuser_version.py') .touch() \"
"${PYTHON_EXECUTABLE}" -c \"from pathlib import Path\; Path('${NVFUSER_PYTHON_DIR}/tools/gen_nvfuser_version.py') .touch() \"
COMMAND
"${PYTHON_EXECUTABLE}" ${NVFUSER_ROOT}/tools/gen_nvfuser_version.py
DEPENDS ${NVFUSER_ROOT}/tools/gen_nvfuser_version.py
DEPENDS ${NVFUSER_ROOT}/version.txt
"${PYTHON_EXECUTABLE}" ${NVFUSER_PYTHON_DIR}/tools/gen_nvfuser_version.py
DEPENDS ${NVFUSER_PYTHON_DIR}/tools/gen_nvfuser_version.py
DEPENDS ${NVFUSER_PYTHON_DIR}/version.txt
WORKING_DIRECTORY ${NVFUSER_ROOT}/tools/
)
add_custom_target(
gen_nvfuser_version ALL
DEPENDS ${NVFUSER_ROOT}/nvfuser/version.py
DEPENDS ${NVFUSER_PYTHON_DIR}/nvfuser/version.py
)
add_dependencies(nvf_py_internal gen_nvfuser_version)

@@ -578,6 +585,7 @@ list(APPEND JIT_TEST_SRCS
${NVFUSER_ROOT}/tests/cpp/test_indexing.cpp
${NVFUSER_ROOT}/tests/cpp/test_indexing_advanced.cpp
${NVFUSER_ROOT}/tests/cpp/test_index_select.cpp
${NVFUSER_ROOT}/tests/cpp/test_index_put.cpp
${NVFUSER_ROOT}/tests/cpp/test_inlining.cpp
${NVFUSER_ROOT}/tests/cpp/test_interval_analysis.cpp
${NVFUSER_ROOT}/tests/cpp/test_iter_visitor.cpp
@@ -732,16 +740,17 @@ if(BUILD_TEST)
list(APPEND HOSTIR_TEST_SRCS
${NVFUSER_ROOT}/tests/cpp/test_host_irs.cpp
${NVFUSER_ROOT}/tests/cpp/test_host_ir_integration.cpp
${NVFUSER_ROOT}/tests/cpp/test_host_ir_stream_lowering.cpp
)
add_test(test_host_ir "${HOSTIR_TEST_SRCS}" "")
list(APPEND TEST_BINARIES test_host_ir)

if(BUILD_PYTHON)
set(PY_FRONTEND_TEST_SRCS)
list(APPEND PY_FRONTEND_TEST_SRCS
${NVFUSER_ROOT}/tests/cpp/python_frontend/test_nvfuser_fusion_cache.cpp
${NVFUSER_ROOT}/tests/cpp/python_frontend/test_nvfuser_fusion_definition.cpp
${NVFUSER_ROOT}/tests/cpp/python_frontend/test_nvfuser_fusion_record.cpp
${NVFUSER_PYTHON_DIR}/tests/python_frontend/test_nvfuser_fusion_cache.cpp
${NVFUSER_PYTHON_DIR}/tests/python_frontend/test_nvfuser_fusion_definition.cpp
${NVFUSER_PYTHON_DIR}/tests/python_frontend/test_nvfuser_fusion_record.cpp
)
add_test(test_python_frontend "${PY_FRONTEND_TEST_SRCS}" "")
list(APPEND TEST_BINARIES test_python_frontend)
10 changes: 10 additions & 0 deletions README.md
@@ -38,6 +38,16 @@ PyPI: [https://pypi.org/project/nvfuser/](https://pypi.org/search/?q=nvfuser)

Docs: https://github.com/NVIDIA/Fuser/wiki

### Install From Source:
```bash
git clone https://github.com/NVIDIA/Fuser.git
cd Fuser
pip install -r python/requirements.txt

# Deprecated: [MAX_JOBS] python setup.py develop [args]
pip install --no-build-isolation -e python -v
```

Supported compilers:

**GCC:**
97 changes: 12 additions & 85 deletions benchmarks/python/core.py
@@ -4,15 +4,14 @@
from collections.abc import Iterable
import pytest_benchmark
import torch
from torch.autograd import DeviceType
from torch.profiler import profile, ProfilerActivity
from typing import List, Callable, Union
import numpy as np
from nvfuser import FusionDefinition, FusionCache
from nvfuser.pytorch_utils import DEVICE_PROPERTIES
import warnings
import thunder
from thunder.executors.nvfuserex import nvfuserex
from nvfuser.benchmark_utils import TorchProfileTimer, FusionProfileTimer

# These variables can be overwritten through CLI commands
# --benchmark-rounds=rounds --benchmark-warmup-rounds=warmup_rounds
@@ -102,34 +101,21 @@ def __init__(
self.benchmark: Underlying pytest-benchmark fixture with timer modified to use torchprofile_timer
self.current_time: Global montonic clock incremented based on elapsed CUDA time
"""

self.device = device
self.fd = None # Set through setup() for host benchmarking.
self.benchmark = benchmark_fixture

# Modify the default timer.
if device == "cuda":
# Initialize a Torch Profiler object
self.prof = profile(
activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU]
)
# Modify the default timer.
benchmark_fixture._timer = self.torchprofile_timer
benchmark_fixture._timer = TorchProfileTimer()
else:
benchmark_fixture._timer = self.fusionprofile_timer
benchmark_fixture._timer = FusionProfileTimer()
# Externally set the precision to avoid timer calibration. Since the timer uses CUDA times,
# calibration using subsequent timer calls produces invalid results.
# https://github.com/ionelmc/pytest-benchmark/blob/728752d2976ef53fde7e40beb3e55f09cf4d4736/src/pytest_benchmark/timers.py#L15
benchmark_fixture._precisions[benchmark_fixture._timer] = precision

self.benchmark = benchmark_fixture

# Global montonic clock
self.current_time = 0.0

# Specifies if the timer in host measurement is called at the start/finish of execution.
# Timings are measured at the end of execution.
self.execution_start = True

def __call__(self, function_to_benchmark: Callable, *args, **kwargs):
return self.benchmark(function_to_benchmark, *args, **kwargs)

@@ -138,73 +124,14 @@ def __getattr__(self, attr):
return getattr(self.benchmark, attr)
return super().__getattr__(attr)

def torchprofile_timer(self) -> float:
"""
Custom torchprofiler-based timer used by pytest-benchmark.
At every timer call, the profiler is stopped to compute the elapsed CUDA time
and the global clock is incremented. The profiler is restarted before returning to continue tracing.

Returns:
self.current_time: Global monotonic clock variable
"""
try:
self.prof.stop()
except AssertionError:
self.prof.start()
return self.current_time

prof_averages = self.prof.key_averages()
elapsed_cuda_time = self._get_kernel_time(prof_averages)
self._increment_global_time(elapsed_cuda_time)
# Clear the internal profiler object to avoid accumulating function events and then restart the profiler
# See PR: https://github.com/pytorch/pytorch/pull/125510
self.prof.profiler = None

return self.current_time

def fusionprofile_timer(self) -> float:
if not self.execution_start:
profile = self.fd.profile()
elapsed_host_time = profile.host_time_ms / 1e3
self._increment_global_time(elapsed_host_time)
self.execution_start = not self.execution_start
return self.current_time

def _get_kernel_time(
self, prof_averages: torch.autograd.profiler_util.EventList
) -> float:
"""
Arguments:
prof_averages: Output of self.prof.key_averages()
Returns:
time_value: Elapsed CUDA time in seconds.
"""
elapsed_cuda_time = 0
has_cuda_event = False
for event in prof_averages:
if event.device_type != DeviceType.CUDA:
continue
has_cuda_event = True
# Re: torch profiler API changes in https://github.com/pytorch/pytorch/pull/123247
elapsed_cuda_time = (
elapsed_cuda_time + event.self_device_time_total
if hasattr(event, "self_device_time_total")
else event.self_cuda_time_total
)
assert has_cuda_event, "No CUDA events found"
return elapsed_cuda_time / 1e6

def _increment_global_time(self, elapsed_time: float) -> None:
self.current_time += elapsed_time
# Set the fd object for fusion profiling.
# fd is returned by setup() for host benchmarking.
def set_fd(self, fd):
assert isinstance(self._timer, FusionProfileTimer)
self._timer.set_fd(fd)

def cleanup(self) -> None:
"""
Stops a running torchprofiler instance if found.
"""
try:
self.prof.stop()
except AssertionError:
pass
def cleanup(self):
self._timer.cleanup()

def set_metrics(
self,
@@ -374,7 +301,7 @@ def setup():
# The host_benchmark_fn uses the `fd` object returned from setup function.
def host_benchmark_fn(inputs, fd):
# Set the fd variable used to query the profile object
nvf_benchmark.fd = fd
nvf_benchmark.set_fd(fd)
return fd.execute(inputs, profile=True)

benchmark_fn = benchmark_fn if benchmark_fn is not None else host_benchmark_fn
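The `TorchProfileTimer` and `FusionProfileTimer` classes imported above now live in `nvfuser.benchmark_utils`, so their bodies are not part of this diff. A minimal, hypothetical sketch of the pattern they share (inferred from the deleted inline implementation, names illustrative): pytest-benchmark accepts any zero-argument callable returning seconds as its timer, so a profiler-backed timer simply advances a manual monotonic clock by externally measured durations and returns it on each call.

```python
class ManualClockTimer:
    """Monotonic clock that only advances by externally measured durations."""

    def __init__(self) -> None:
        self.current_time = 0.0

    def increment(self, elapsed_seconds: float) -> None:
        # In the real timers this would be elapsed CUDA time from a torch
        # profiler trace, or host time from fd.profile().
        self.current_time += elapsed_seconds

    def __call__(self) -> float:
        # pytest-benchmark invokes the timer as a plain callable returning
        # seconds; returning a manually advanced clock keeps it monotonic.
        return self.current_time


timer = ManualClockTimer()
start = timer()
timer.increment(0.25)
assert timer() - start == 0.25
```

Because such a timer only moves when told to, calibration via repeated timer calls would measure nothing, which is why the fixture above pins `_precisions` externally instead of letting pytest-benchmark calibrate.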
4 changes: 2 additions & 2 deletions benchmarks/python/test_cross_entropy_loss.py
@@ -20,7 +20,7 @@
@pytest.mark.parametrize(
"executor", ["eager", "torchcompile", "thunder", "thunder-torchcompile"]
)
def test_rope_fwd_benchmark(
def test_cross_entropy_fwd_benchmark(
benchmark,
variation: str,
executor: str,
@@ -52,7 +52,7 @@ def fwd_call(inp):
@pytest.mark.parametrize(
"executor", ["eager", "torchcompile", "thunder", "thunder-torchcompile"]
)
def test_rope_bwd_benchmark(
def test_cross_entropy_bwd_benchmark(
benchmark,
variation: str,
executor: str,
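The renames above fix benchmark functions that were copy-pasted from the RoPE suite: pytest identifies each generated case as `function_name[param-id]`, so a cross-entropy case still named `test_rope_fwd_benchmark` would be indistinguishable from a genuine RoPE benchmark in reports. A small illustrative sketch of that id scheme (the helper name is hypothetical, not pytest API):

```python
def make_test_id(func_name: str, params: dict) -> str:
    """Roughly mimic pytest's 'name[param-ids]' test identifiers."""
    param_id = "-".join(str(v) for v in params.values())
    return f"{func_name}[{param_id}]"


# Before the rename, the cross-entropy benchmark reported under a RoPE name.
before = make_test_id("test_rope_fwd_benchmark", {"executor": "eager"})
after = make_test_id("test_cross_entropy_fwd_benchmark", {"executor": "eager"})
assert after == "test_cross_entropy_fwd_benchmark[eager]"
assert before != after
```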