diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ae099633d78..1186bd19da9 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -31,6 +31,7 @@ jobs: tools/pip-install-things.sh & source tools/setup-env.sh wait + cd python python setup.py build --cpp=23 dynamic-type-meson: diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index a9168522e11..2cc23e73227 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -51,6 +51,9 @@ jobs: wait + # Go to python folder to build cmake files + cd python + # Run cmake build python setup.py --cmake-only @@ -58,6 +61,9 @@ jobs: # NOTE: this might cause a compile of flatbuffers if it is missing ninja -C build build_flatbuffer_config + # Return to root to run clang-tidy + cd .. + # Run lintrunner on all csrc files exclude benchmark and test folders this_commit=$(git rev-parse HEAD) git fetch origin main diff --git a/.gitignore b/.gitignore index 89d7c587c4b..82a26694d84 100644 --- a/.gitignore +++ b/.gitignore @@ -4,20 +4,24 @@ bin # cmake build directory build .lintbin - -# pip wheel directory -dist - nvfuser/version.py nvfuser/include nvfuser/lib nvfuser/share nvfuser/cmake +python/build +python/nvfuser/version.py +python/nvfuser/include +python/nvfuser/lib +python/nvfuser/share +python/nvfuser/cmake + .hypothesis *.egg-info/ **/__pycache__ */*.so +python/nvfuser/*.so # Editor temporaries *.swa diff --git a/.lintrunner.toml b/.lintrunner.toml index 7fcac6c3c4d..f10037e39bb 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -16,7 +16,7 @@ init_command = [ 'python3', 'tools/linter/adapters/pip_init.py', '--dry-run={{DRYRUN}}', - 'flake8==6.0.0', + 'flake8==6.1.0', ] @@ -185,7 +185,7 @@ command = [ 'python3', 'tools/linter/adapters/clangtidy_linter.py', '--binary=~/.local/bin/clang-tidy', - '--build_dir=./build', + '--build_dir=./python/build', '--', '@{{PATHSFILE}}' ] diff --git a/CMakeLists.txt b/CMakeLists.txt index 226e1acc396..b09c6adb22f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -10,6 +10,7 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(NVFUSER_ROOT ${PROJECT_SOURCE_DIR}) set(NVFUSER_SRCS_DIR "${NVFUSER_ROOT}/csrc") +set(NVFUSER_PYTHON_DIR "${NVFUSER_ROOT}/python") set(NVFUSER_THIRD_PARTY_DIR "${NVFUSER_ROOT}/third_party") option(NVFUSER_STANDALONE_BUILD_WITH_UCC "" OFF) @@ -212,6 +213,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/preseg_passes/remove_empty.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/reorder_sharded_axis.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/segment_inplace_update.cpp + ${NVFUSER_SRCS_DIR}/host_ir/pass/stream_parallel_type.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/translate_no_reduction_matmul_to_mul_squeeze.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/translate_repeat_to_expand.cpp ${NVFUSER_SRCS_DIR}/rng.cpp @@ -239,6 +241,9 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/scheduler/communication.cpp ${NVFUSER_SRCS_DIR}/scheduler/normalization_inner.cpp ${NVFUSER_SRCS_DIR}/scheduler/normalization_inner_outer.cpp + ${NVFUSER_SRCS_DIR}/scheduler/normalization_inner_outer_utils.cpp + ${NVFUSER_SRCS_DIR}/scheduler/normalization_inner_outer_tma_ws.cpp + ${NVFUSER_SRCS_DIR}/scheduler/normalization_inner_outer_multi_wave.cpp ${NVFUSER_SRCS_DIR}/scheduler/normalization_outer.cpp ${NVFUSER_SRCS_DIR}/scheduler/normalization_utils.cpp ${NVFUSER_SRCS_DIR}/scheduler/pointwise.cpp @@ -289,13 +294,13 @@ endif() if(BUILD_PYTHON) list(APPEND NVFUSER_SRCS - ${NVFUSER_SRCS_DIR}/python_frontend/distributed_tensor.cpp - ${NVFUSER_SRCS_DIR}/python_frontend/fusion_cache.cpp - ${NVFUSER_SRCS_DIR}/python_frontend/fusion_definition.cpp - ${NVFUSER_SRCS_DIR}/python_frontend/fusion_state.cpp - ${NVFUSER_SRCS_DIR}/python_frontend/segmentation.cpp - ${NVFUSER_SRCS_DIR}/python_frontend/translation.cpp - ${NVFUSER_SRCS_DIR}/python_frontend/translation_utils.cpp + ${NVFUSER_PYTHON_DIR}/python_frontend/distributed_tensor.cpp + ${NVFUSER_PYTHON_DIR}/python_frontend/fusion_cache.cpp + ${NVFUSER_PYTHON_DIR}/python_frontend/fusion_definition.cpp + ${NVFUSER_PYTHON_DIR}/python_frontend/fusion_state.cpp + ${NVFUSER_PYTHON_DIR}/python_frontend/segmentation.cpp + ${NVFUSER_PYTHON_DIR}/python_frontend/translation.cpp + ${NVFUSER_PYTHON_DIR}/python_frontend/translation_utils.cpp ${NVFUSER_SRCS_DIR}/serde/fusion_record.cpp ) endif() @@ -331,6 +336,7 @@ if(NOT MSVC) endif() target_compile_definitions(codegen_internal PRIVATE "-DTORCH_CUDA_BUILD_MAIN_LIB") +target_include_directories(codegen_internal PUBLIC ${NVFUSER_PYTHON_DIR}) target_include_directories(codegen_internal SYSTEM PUBLIC ${CMAKE_SOURCE_DIR}/third_party/flatbuffers/include PRIVATE @@ -457,31 +463,32 @@ if(BUILD_PYTHON) # nvfuser python API sources set(NVFUSER_PYTHON_SRCS) list(APPEND NVFUSER_PYTHON_SRCS - ${NVFUSER_SRCS_DIR}/python_frontend/multidevice_bindings.cpp - ${NVFUSER_SRCS_DIR}/python_frontend/python_bindings.cpp - ${NVFUSER_SRCS_DIR}/python_frontend/python_bindings_extension.cpp - ${NVFUSER_SRCS_DIR}/python_frontend/schedule_bindings.cpp + ${NVFUSER_PYTHON_DIR}/python_frontend/multidevice_bindings.cpp + ${NVFUSER_PYTHON_DIR}/python_frontend/python_bindings.cpp + ${NVFUSER_PYTHON_DIR}/python_frontend/python_bindings_extension.cpp + ${NVFUSER_PYTHON_DIR}/python_frontend/schedule_bindings.cpp ) add_library(nvf_py_internal OBJECT ${NVFUSER_PYTHON_SRCS}) + target_include_directories(nvf_py_internal PUBLIC ${NVFUSER_PYTHON_DIR}) target_include_directories(nvf_py_internal SYSTEM INTERFACE ${CMAKE_SOURCE_DIR}/third_party/flatbuffers/include ) # setup python API version add_custom_command( - OUTPUT ${NVFUSER_ROOT}/nvfuser/version.py + OUTPUT ${NVFUSER_PYTHON_DIR}/nvfuser/version.py COMMAND - "${PYTHON_EXECUTABLE}" -c \"from pathlib import Path\; Path('${NVFUSER_ROOT}/tools/gen_nvfuser_version.py') .touch() \" + "${PYTHON_EXECUTABLE}" -c \"from pathlib import Path\; Path('${NVFUSER_PYTHON_DIR}/tools/gen_nvfuser_version.py') .touch() \" COMMAND - "${PYTHON_EXECUTABLE}" ${NVFUSER_ROOT}/tools/gen_nvfuser_version.py - DEPENDS ${NVFUSER_ROOT}/tools/gen_nvfuser_version.py - DEPENDS ${NVFUSER_ROOT}/version.txt + "${PYTHON_EXECUTABLE}" ${NVFUSER_PYTHON_DIR}/tools/gen_nvfuser_version.py + DEPENDS ${NVFUSER_PYTHON_DIR}/tools/gen_nvfuser_version.py + DEPENDS ${NVFUSER_PYTHON_DIR}/version.txt WORKING_DIRECTORY ${NVFUSER_ROOT}/tools/ ) add_custom_target( gen_nvfuser_version ALL - DEPENDS ${NVFUSER_ROOT}/nvfuser/version.py + DEPENDS ${NVFUSER_PYTHON_DIR}/nvfuser/version.py ) add_dependencies(nvf_py_internal gen_nvfuser_version) @@ -578,6 +585,7 @@ list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/tests/cpp/test_indexing.cpp ${NVFUSER_ROOT}/tests/cpp/test_indexing_advanced.cpp ${NVFUSER_ROOT}/tests/cpp/test_index_select.cpp + ${NVFUSER_ROOT}/tests/cpp/test_index_put.cpp ${NVFUSER_ROOT}/tests/cpp/test_inlining.cpp ${NVFUSER_ROOT}/tests/cpp/test_interval_analysis.cpp ${NVFUSER_ROOT}/tests/cpp/test_iter_visitor.cpp @@ -732,6 +740,7 @@ if(BUILD_TEST) list(APPEND HOSTIR_TEST_SRCS ${NVFUSER_ROOT}/tests/cpp/test_host_irs.cpp ${NVFUSER_ROOT}/tests/cpp/test_host_ir_integration.cpp + ${NVFUSER_ROOT}/tests/cpp/test_host_ir_stream_lowering.cpp ) add_test(test_host_ir "${HOSTIR_TEST_SRCS}" "") list(APPEND TEST_BINARIES test_host_ir) @@ -739,9 +748,9 @@ if(BUILD_TEST) if(BUILD_PYTHON) set(PY_FRONTEND_TEST_SRCS) list(APPEND PY_FRONTEND_TEST_SRCS - ${NVFUSER_ROOT}/tests/cpp/python_frontend/test_nvfuser_fusion_cache.cpp - ${NVFUSER_ROOT}/tests/cpp/python_frontend/test_nvfuser_fusion_definition.cpp - ${NVFUSER_ROOT}/tests/cpp/python_frontend/test_nvfuser_fusion_record.cpp + ${NVFUSER_PYTHON_DIR}/tests/python_frontend/test_nvfuser_fusion_cache.cpp + ${NVFUSER_PYTHON_DIR}/tests/python_frontend/test_nvfuser_fusion_definition.cpp + ${NVFUSER_PYTHON_DIR}/tests/python_frontend/test_nvfuser_fusion_record.cpp ) add_test(test_python_frontend "${PY_FRONTEND_TEST_SRCS}" "") list(APPEND TEST_BINARIES test_python_frontend) diff --git a/README.md b/README.md index 32c3bde8f4e..a00e09921c2 100644 --- a/README.md +++ b/README.md @@ -38,6 +38,16 @@ PyPI: [https://pypi.org/project/nvfuser/](https://pypi.org/search/?q=nvfuser) Docs: https://github.com/NVIDIA/Fuser/wiki +### Install From Source: +```bash +git clone https://github.com/NVIDIA/Fuser.git +cd Fuser +pip install -r python/requirements.txt + +[DEPRECATED] `[MAX_JOBS] python setup.py develop [args]` +pip install --no-build-isolation -e python -v +``` + Supported compilers: **GCC:** diff --git a/benchmarks/python/core.py b/benchmarks/python/core.py index c56d931cd35..735797c247b 100644 --- a/benchmarks/python/core.py +++ b/benchmarks/python/core.py @@ -4,8 +4,6 @@ from collections.abc import Iterable import pytest_benchmark import torch -from torch.autograd import DeviceType -from torch.profiler import profile, ProfilerActivity from typing import List, Callable, Union import numpy as np from nvfuser import FusionDefinition, FusionCache @@ -13,6 +11,7 @@ import warnings import thunder from thunder.executors.nvfuserex import nvfuserex +from nvfuser.benchmark_utils import TorchProfileTimer, FusionProfileTimer # These variables can be overwritten through CLI commands # --benchmark-rounds=rounds --benchmark-warmup-rounds=warmup_rounds @@ -102,20 +101,14 @@ def __init__( self.benchmark: Underlying pytest-benchmark fixture with timer modified to use torchprofile_timer self.current_time: Global montonic clock incremented based on elapsed CUDA time """ - self.device = device - self.fd = None # Set through setup() for host benchmarking. self.benchmark = benchmark_fixture + # Modify the default timer. if device == "cuda": - # Initialize a Torch Profiler object - self.prof = profile( - activities=[ProfilerActivity.CUDA, ProfilerActivity.CPU] - ) - # Modify the default timer. - benchmark_fixture._timer = self.torchprofile_timer + benchmark_fixture._timer = TorchProfileTimer() else: - benchmark_fixture._timer = self.fusionprofile_timer + benchmark_fixture._timer = FusionProfileTimer() # Externally set the precision to avoid timer calibration. Since the timer uses CUDA times, # calibration using subsequent timer calls produces invalid results. # https://github.com/ionelmc/pytest-benchmark/blob/728752d2976ef53fde7e40beb3e55f09cf4d4736/src/pytest_benchmark/timers.py#L15 @@ -123,13 +116,6 @@ def __init__( self.benchmark = benchmark_fixture - # Global montonic clock - self.current_time = 0.0 - - # Specifies if the timer in host measurement is called at the start/finish of execution. - # Timings are measured at the end of execution. - self.execution_start = True - def __call__(self, function_to_benchmark: Callable, *args, **kwargs): return self.benchmark(function_to_benchmark, *args, **kwargs) @@ -138,73 +124,14 @@ def __getattr__(self, attr): return getattr(self.benchmark, attr) return super().__getattr__(attr) - def torchprofile_timer(self) -> float: - """ - Custom torchprofiler-based timer used by pytest-benchmark. - At every timer call, the profiler is stopped to compute the elapsed CUDA time - and the global clock is incremented. The profiler is restarted before returning to continue tracing. - - Returns: - self.current_time: Global monotonic clock variable - """ - try: - self.prof.stop() - except AssertionError: - self.prof.start() - return self.current_time - - prof_averages = self.prof.key_averages() - elapsed_cuda_time = self._get_kernel_time(prof_averages) - self._increment_global_time(elapsed_cuda_time) - # Clear the internal profiler object to avoid accumulating function events and then restart the profiler - # See PR: https://github.com/pytorch/pytorch/pull/125510 - self.prof.profiler = None - - return self.current_time - - def fusionprofile_timer(self) -> float: - if not self.execution_start: - profile = self.fd.profile() - elapsed_host_time = profile.host_time_ms / 1e3 - self._increment_global_time(elapsed_host_time) - self.execution_start = not self.execution_start - return self.current_time - - def _get_kernel_time( - self, prof_averages: torch.autograd.profiler_util.EventList - ) -> float: - """ - Arguments: - prof_averages: Output of self.prof.key_averages() - Returns: - time_value: Elapsed CUDA time in seconds. - """ - elapsed_cuda_time = 0 - has_cuda_event = False - for event in prof_averages: - if event.device_type != DeviceType.CUDA: - continue - has_cuda_event = True - # Re: torch profiler API changes in https://github.com/pytorch/pytorch/pull/123247 - elapsed_cuda_time = ( - elapsed_cuda_time + event.self_device_time_total - if hasattr(event, "self_device_time_total") - else event.self_cuda_time_total - ) - assert has_cuda_event, "No CUDA events found" - return elapsed_cuda_time / 1e6 - - def _increment_global_time(self, elapsed_time: float) -> None: - self.current_time += elapsed_time + # Set the fd object for fusion profiling. + # fd is returned by setup() for host benchmarking. + def set_fd(self, fd): + assert isinstance(self._timer, FusionProfileTimer) + self._timer.set_fd(fd) - def cleanup(self) -> None: - """ - Stops a running torchprofiler instance if found. - """ - try: - self.prof.stop() - except AssertionError: - pass + def cleanup(self): + self._timer.cleanup() def set_metrics( self, @@ -374,7 +301,7 @@ def setup(): # The host_benchmark_fn uses the `fd` object returned from setup function. def host_benchmark_fn(inputs, fd): # Set the fd variable used to query the profile object - nvf_benchmark.fd = fd + nvf_benchmark.set_fd(fd) return fd.execute(inputs, profile=True) benchmark_fn = benchmark_fn if benchmark_fn is not None else host_benchmark_fn diff --git a/benchmarks/python/test_cross_entropy_loss.py b/benchmarks/python/test_cross_entropy_loss.py index 88bb2101a90..6d9124c56e3 100644 --- a/benchmarks/python/test_cross_entropy_loss.py +++ b/benchmarks/python/test_cross_entropy_loss.py @@ -20,7 +20,7 @@ @pytest.mark.parametrize( "executor", ["eager", "torchcompile", "thunder", "thunder-torchcompile"] ) -def test_rope_fwd_benchmark( +def test_cross_entropy_fwd_benchmark( benchmark, variation: str, executor: str, @@ -52,7 +52,7 @@ def fwd_call(inp): @pytest.mark.parametrize( "executor", ["eager", "torchcompile", "thunder", "thunder-torchcompile"] ) -def test_rope_bwd_benchmark( +def test_cross_entropy_bwd_benchmark( benchmark, variation: str, executor: str, diff --git a/benchmarks/python/test_matmul.py b/benchmarks/python/test_matmul.py index 669b032bd3f..f7886178198 100644 --- a/benchmarks/python/test_matmul.py +++ b/benchmarks/python/test_matmul.py @@ -41,6 +41,9 @@ def test_matmul_baseline_benchmark( ): m, n, k, layout = config + if (m * k + n * k + m * n) * 2 > 20 * (2**30): + pytest.skip("Case takes more than 20GiB. Skipping to avoid OOM") + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = half_reduction torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = half_reduction @@ -75,6 +78,9 @@ def test_matmul_nvf_benchmark( ): m, n, k, layout = config + if (m * k + n * k + m * n) * 2 > 20 * (2**30): + pytest.skip("Case takes more than 20GiB. Skipping to avoid OOM") + torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = half_reduction torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = half_reduction diff --git a/csrc/alias_analysis.cpp b/csrc/alias_analysis.cpp index 8bca891e4c4..8fd5da9da77 100644 --- a/csrc/alias_analysis.cpp +++ b/csrc/alias_analysis.cpp @@ -239,10 +239,6 @@ void AliasFinder::handle(const ViewOp* view) { } void AliasFinder::handle(const LoadStoreOp* set) { - if (isResharding(set)) { - return; - } - TensorView* in = dynamic_cast(set->in()); if (in == nullptr) { return; diff --git a/csrc/codegen.cpp b/csrc/codegen.cpp index 92e47b3d01e..3bdaa49aea8 100644 --- a/csrc/codegen.cpp +++ b/csrc/codegen.cpp @@ -1283,7 +1283,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { tidy->value().as() + tidz->value().as(); NVF_ERROR( num_threads == 128, - "Expected 128 threads in LoadWarp, but found ", + "Expected 128 threads in AsyncWarp, but found ", num_threads); NVF_ERROR(pdim_map.hasWarpSpecialization()); ss << "dim3(" << genInlineOrOne(tidx) << ", " << genInlineOrOne(tidy) @@ -3557,7 +3557,7 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { indent() << "block_sync::sync();\n"; } else if (isAligned()) { indent() << "__syncthreads();\n"; - } else if (sync->isLoadWarpSync()) { + } else if (sync->isAsyncWarpSync()) { ArgumentBuilder template_args; template_args.arg(isAligned()); ArgumentBuilder func_args; @@ -3749,6 +3749,10 @@ class CudaKernelGenerator : private kir::ConstIrVisitor { indent() << "NVFUSER_UPDATE_MAGIC_ZERO;\n"; } + void handle(const kir::Continue* cont) final { + indent() << "continue;\n"; + } + void handle(const kir::Return* ret) final { indent() << "return;\n"; } diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index eb230c38321..db77253bed7 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -75,7 +75,7 @@ namespace { bool isComputeWarp(TensorView* consumer, IterDomain* id_in_consumer) { // TODO: This function can not find all the expressions in the compute // warp. For example, if we have: - // if (load warp) { + // if (async warp) { // T1 = T0; // } else { // T2 = T1; diff --git a/csrc/device_lower/lower2device.h b/csrc/device_lower/lower2device.h index 3e60d13621c..883c53f1ba9 100644 --- a/csrc/device_lower/lower2device.h +++ b/csrc/device_lower/lower2device.h @@ -227,12 +227,12 @@ class GpuLower : public NonCopyable { return profile_; } - std::unordered_map& ldstMBarrierMap() { - return ldst_mbarrier_map_; + std::unordered_map& mbarrierMap() { + return mbarrier_map_; } - const std::unordered_map& ldstMBarrierMap() const { - return ldst_mbarrier_map_; + const std::unordered_map& mbarrierMap() const { + return mbarrier_map_; } bool isNvFuserZeroEnabled() { @@ -432,8 +432,8 @@ class GpuLower : public NonCopyable { // precomputed values std::vector all_known_vals_; - // Keep track of the mbarrier used for each load/store operation - std::unordered_map ldst_mbarrier_map_; + // Keep track of the mbarrier used for each load/store and blackwell utcmma + std::unordered_map mbarrier_map_; // Information about tensor memory usage TensorMemoryInfo tmem_info_; diff --git a/csrc/device_lower/pass/allocation.cpp b/csrc/device_lower/pass/allocation.cpp index 47e841b92ef..8a869073be3 100644 --- a/csrc/device_lower/pass/allocation.cpp +++ b/csrc/device_lower/pass/allocation.cpp @@ -1347,7 +1347,8 @@ class AllocationInserter : public kir::ExprMutator { // circular buffering pass. // * Assume that the tma load is in ComputeWarp if it is not circular // buffered. - if (ir_utils::isCpAsyncBulkLoad(expr) && circular_buffer_depth == 1) { + if ((ir_utils::isCpAsyncBulkLoad(expr) && circular_buffer_depth == 1) || + (expr->isA() && expr->as()->isBlackwell())) { // create and allocate a memory barrier TensorView* mbarrier = TensorViewBuilder() .shape(std::vector{}) @@ -1359,8 +1360,9 @@ class AllocationInserter : public kir::ExprMutator { mbarrier, simplifyExpr(SimplifyingIrBuilder::maybeCastExpr( DataType::UInt32, - lower_utils::getNumThreadsInTensorView( - expr->output(0)->as())))); + expr->isA() ? expr->fusion()->oneVal() + : lower_utils::getNumThreadsInTensorView( + expr->output(0)->as())))); auto sync_init = IrBuilder::create( /*war_sync=*/false, /*optional_compute_or_load_sync=*/true); auto mbarrier_inval = @@ -1376,7 +1378,7 @@ class AllocationInserter : public kir::ExprMutator { registerInsertBefore(expr, sync_init, expr_scope); registerInsertAfter(expr, mbarrier_inval, expr_scope); registerInsertAfter(expr, sync_inval, expr_scope); - GpuLower::current()->ldstMBarrierMap()[expr] = mbarrier; + GpuLower::current()->mbarrierMap()[expr] = mbarrier; } } @@ -1484,7 +1486,7 @@ class AllocationInserter : public kir::ExprMutator { continue; } // Map LoadStoreOp expression to ir nodes created in this pass - GpuLower::current()->ldstMBarrierMap()[tv->definition()] = mbarrier; + GpuLower::current()->mbarrierMap()[tv->definition()] = mbarrier; } } } diff --git a/csrc/device_lower/pass/circular_buffer.cpp b/csrc/device_lower/pass/circular_buffer.cpp index 83626d653d1..02bb05050ac 100644 --- a/csrc/device_lower/pass/circular_buffer.cpp +++ b/csrc/device_lower/pass/circular_buffer.cpp @@ -106,7 +106,7 @@ class CircularBufferLoopCloner : public kir::IrVisitor { SimplifyingIrBuilder::create(opt.prefetch, DataType::Index)); break; } - case CircularBufferLoopStage::LoadWarp: + case CircularBufferLoopStage::AsyncWarp: case CircularBufferLoopStage::ComputeWarp: { break; } @@ -618,7 +618,7 @@ class CloneTmaCircularBufferLoopAndInsertSync return; } - const auto& ldst_mbarrier_map = GpuLower::current()->ldstMBarrierMap(); + const auto& ldst_mbarrier_map = GpuLower::current()->mbarrierMap(); for (auto tv : ir_utils::filterByType(expr->inputs())) { // short-circuit: The TensorView input for current expression is not @@ -657,7 +657,7 @@ class CloneTmaCircularBufferLoopAndInsertSync !hasCircularBufferLoad()) { return; } - const auto& ldst_mbarrier_map = GpuLower::current()->ldstMBarrierMap(); + const auto& ldst_mbarrier_map = GpuLower::current()->mbarrierMap(); for (auto tv : ir_utils::filterByType(expr->outputs())) { // short-circuit: The current expression is not a circular buffer load, so @@ -694,7 +694,7 @@ class CloneTmaCircularBufferLoopAndInsertSync return; } - const auto& ldst_mbarrier_map = GpuLower::current()->ldstMBarrierMap(); + const auto& ldst_mbarrier_map = GpuLower::current()->mbarrierMap(); // remove expr from war_mbarriers_to_uses_ auto input_tvs = ir_utils::filterByType(expr->inputs()); for (auto tv : input_tvs) { @@ -799,7 +799,7 @@ class CloneTmaCircularBufferLoopAndInsertSync // expressions std::unordered_map getAllMbarriersToWait() { - const auto& ldst_mbarrier_map = GpuLower::current()->ldstMBarrierMap(); + const auto& ldst_mbarrier_map = GpuLower::current()->mbarrierMap(); std::unordered_map wait_exprs; for (auto tv : circular_buffer_load_tvs_) { LoadStoreOp* ldst = dynamic_cast(tv->definition()); @@ -820,7 +820,7 @@ class CloneTmaCircularBufferLoopAndInsertSync // buffer tensor tracked by this mbarrier. std::unordered_map> getAllWarMbarriersToUses() { - const auto& ldst_mbarrier_map = GpuLower::current()->ldstMBarrierMap(); + const auto& ldst_mbarrier_map = GpuLower::current()->mbarrierMap(); std::unordered_map> mbarrier_to_uses; auto exprs = ir_utils::flattenScopedExprs(circular_buffer_loop_->body().exprs()); @@ -949,7 +949,7 @@ class CloneTmaCircularBufferLoopAndInsertSync NVF_ERROR(ldst != nullptr); // Get mbarrier for this circular buffer stage. - TensorView* all_mbarriers = GpuLower::current()->ldstMBarrierMap().at(ldst); + TensorView* all_mbarriers = GpuLower::current()->mbarrierMap().at(ldst); kir::TensorIndex* stage_mbarrier = IrBuilder::create(all_mbarriers, currentLoadStage()); @@ -985,7 +985,7 @@ class CloneTmaCircularBufferLoopAndInsertSync NVF_ERROR(ldst != nullptr); // Get mbarrier for this circular buffer stage. - TensorView* all_mbarriers = GpuLower::current()->ldstMBarrierMap().at(ldst); + TensorView* all_mbarriers = GpuLower::current()->mbarrierMap().at(ldst); kir::TensorIndex* stage_mbarrier = IrBuilder::create( all_mbarriers, currentComputeStage()); @@ -1007,7 +1007,7 @@ class CloneTmaCircularBufferLoopAndInsertSync .stage; // Get mbarrier for this circular buffer stage. - TensorView* all_mbarriers = GpuLower::current()->ldstMBarrierMap().at(ldst); + TensorView* all_mbarriers = GpuLower::current()->mbarrierMap().at(ldst); kir::TensorIndex* stage_mbarrier = IrBuilder::create( all_mbarriers, SimplifyingIrBuilder::addExpr(currentLoadStage(), stage_depth)); @@ -1144,13 +1144,112 @@ class IsCircularBufferLoadLoop : public kir::IrVisitor { bool result_ = false; }; +namespace { + +bool isWarpSpecialized(ForLoop* loop) { + return std::holds_alternative( + GpuLower::current() + ->circularBufferInfo() + .getCircularBufferOptionsFor(loop->iter_domain()) + .type); +} + +} // namespace + // Traverse lowered loop-nests and find all circular buffer loops and // associated load expressions. class CircularBufferLoopNestInspector : private kir::IrVisitor { public: - static InsertionInfo run(const std::vector& exprs) { + static std::pair run( + const std::vector& exprs) { CircularBufferLoopNestInspector inspector(exprs); - return inspector.insertion_info_; + + // InsertionInfo holds all circular buffer for-loops. Split it into warp + // specialized and pipeline circular buffers. Enforce that we can only nest + // pipeline circular buffering inside of warp-specialization. + + // Get WarpSpecialized InsertionInfo + InsertionInfo ws_info; + int64_t inner_most_ws_position = -1; + for (auto&& [cb_loop, cb_exprs] : inspector.insertion_info_) { + if (!isWarpSpecialized(cb_loop)) { + continue; + } + ws_info[cb_loop] = cb_exprs; + inner_most_ws_position = std::max( + inner_most_ws_position, inspector.loop_position_.at(cb_loop)); + } + + // WarpSpecialized circular buffering pads the thread block size by 128 + // threads. This is to support register sharing, which shares registers from + // four warps to another four warps. Thus, we can have four warps running + // concurrently in AsyncWarp. Each warp can launch an asynchronous operation + // with mbarrier completion mechanism such as TMA Load and Blackwell UTCMMA. + // + // if (Select AsyncWarp) { + // if (Select Warp 0 AND elect-sync()) { + // do-something + // } else if (Select Warp 1 AND elect-sync()) { + // do-something + // } else if (Select Warp 2 AND elect-sync()) { + // do-something + // } else if (Select Warp 3 AND elect-sync()) { + // do-something + // } + // } + NVF_ERROR( + ws_info.size() <= 4, + "At most four for-loops can run concurrently inside the AsyncWarp.\n", + "Detected ", + ws_info.size(), + " WarpSpecialized for-loops."); + + // Get Pipeline InsertionInfo + InsertionInfo pipeline_info; + for (auto&& [cb_loop, cb_exprs] : inspector.insertion_info_) { + if (isWarpSpecialized(cb_loop)) { + continue; + } + + // An example of WarpSpecialized circular buffer nested in Pipeline + // circular buffer. + // * Register sharing would fail because of the return in the AsyncLoop. + // * This scenario is not actively tested, so prohibit it until a valid + // use-case occurs. + // + // warp-specialized mbarrier init + // for (prologue) { + // load something for Prologue + // } + // + // for (main) { + // load something for Main + // if (AsyncWarp) { + // launch async + // maybe return for register sharing + // } else { + // compute something for ComputeWarp + // } + // compute something for Main + // } + // + // for (epilogue) { + // if (AsyncWarp) { + // launch async + // maybe return for register sharing + // } else { + // compute something + // } + // compute something for Epilogue + // } + // warp-specialized mbarrier inval + NVF_ERROR( + inspector.loop_position_.at(cb_loop) > inner_most_ws_position, + "Warp Specialization cannot be nested in Pipeline circular buffering!"); + pipeline_info[cb_loop] = cb_exprs; + } + + return {ws_info, pipeline_info}; } private: @@ -1186,6 +1285,10 @@ class CircularBufferLoopNestInspector : private kir::IrVisitor { validateCircularBufferLoop(circular_buffer_loop); + auto cb_loop_it = + std::find(for_loops_.begin(), for_loops_.end(), circular_buffer_loop); + loop_position_[circular_buffer_loop] = + std::distance(for_loops_.begin(), cb_loop_it); insertion_info_[circular_buffer_loop].push_back(expr); } @@ -1211,6 +1314,8 @@ class CircularBufferLoopNestInspector : private kir::IrVisitor { loop->toString()); } + // Map circular buffer loop to its position in the for_loop_ stack. + std::unordered_map loop_position_; InsertionInfo insertion_info_; }; @@ -1229,10 +1334,51 @@ void getAllocInTrivialLoop(ForLoop* fl, std::unordered_set& output) { } } +// Create something like below: +// for (int i = 0; i < prefetch + 1; ++i) { +// mbarrier::arrive(mbarrier0[stage + i]]); +// mbarrier::arrive(mbarrier1[stage + i]); +// ... +// } +// where mbarrierX[stage + i] is the X-th WAR mbarrier for stage i. +// +// This is needed because we prefetch data in circular buffering, and we +// need to make sure the initial prefetches are not blocked by the +// non-existing WAR hazards. +ForLoop* createArrivesForWar(ForLoop* circular_buffer_loop) { + const auto& opt = + GpuLower::current()->circularBufferInfo().getCircularBufferOptionsFor( + circular_buffer_loop->iter_domain()); + auto circular_buffer_tvs = + GpuLower::current()->circularBufferInfo().getCircularBufferTvs( + circular_buffer_loop->iter_domain()); + VectorOfUniqueEntries mbarriers; + for (auto tv : circular_buffer_tvs) { + auto ldst = dynamic_cast(tv->definition()); + NVF_ERROR(ldst != nullptr); + auto it = GpuLower::current()->mbarrierMap().find(ldst); + if (it == GpuLower::current()->mbarrierMap().end()) { + continue; + } + mbarriers.pushBack(it->second); + } + auto prefetch_loop = ir_utils::createRangeLoop(opt.prefetch + 1); + for (auto mbarrier : mbarriers) { + auto mbarrier_to_arrive = IrBuilder::create( + mbarrier, + SimplifyingIrBuilder::addExpr( + prefetch_loop->indexOrStartIfTrivial(), opt.stage)); + auto prefetch = IrBuilder::create( + /*state=*/nullptr, mbarrier_to_arrive); + prefetch_loop->body().push_back(prefetch); + } + return prefetch_loop; +} + } // namespace -// Apply circular buffering transformations -class CircularBufferInserter : private kir::ExprMutator { +// Apply warp specialized circular buffering transformations +class WarpSpecializedCircularBufferInserter : private kir::ExprMutator { public: // When there exist multiple circular buffer loops, apply // transformations to inner-most loops first. A single ExprMutator @@ -1242,14 +1388,15 @@ class CircularBufferInserter : private kir::ExprMutator { InsertionInfo insertion_info) { std::vector inserted_exprs = exprs; while (!insertion_info.empty()) { - CircularBufferInserter inserter(inserted_exprs, insertion_info); + WarpSpecializedCircularBufferInserter inserter( + inserted_exprs, insertion_info); inserted_exprs = inserter.exprs_; } return inserted_exprs; } private: - CircularBufferInserter( + WarpSpecializedCircularBufferInserter( const std::vector& exprs, InsertionInfo& insertion_info) : insertion_info_(insertion_info) { @@ -1275,143 +1422,24 @@ class CircularBufferInserter : private kir::ExprMutator { return; } - auto has_cp_async_bulk = std::any_of( - it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk); - bool use_warp_specialization = std::holds_alternative( GpuLower::current() ->circularBufferInfo() .getCircularBufferOptionsFor(loop->iter_domain()) .type); - if (use_warp_specialization) { - NVF_ERROR( - std::all_of( - it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk), - "In order to use warp specialization, all buffers must be loaded by TMA"); - int64_t insertion_position = - GpuLower::current() - ->circularBufferInfo() - .getCircularBufferInsertionPosition(loop->iter_domain()); - insertTmaWarpSpecialized(loop, it->second, insertion_position); - } else if (has_cp_async_bulk) { - insertTmaPipelined(loop, it->second); - } else { - insert(loop, it->second); - } - processed_loop_ = loop; - insertion_info_.erase(loop); - } - - bool hasPrefetch(ForLoop* circular_buffer_loop) { - int64_t prefetch_distance = + NVF_ERROR(use_warp_specialization); + NVF_ERROR( + std::all_of( + it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk), + "In order to use warp specialization, all buffers must be loaded by TMA"); + int64_t insertion_position = GpuLower::current() ->circularBufferInfo() - .getCircularBufferOptionsFor(circular_buffer_loop->iter_domain()) - .prefetch; - return prefetch_distance > 0; - } + .getCircularBufferInsertionPosition(loop->iter_domain()); + insertTmaWarpSpecialized(loop, it->second, insertion_position); - // Create something like below: - // for (int i = 0; i < prefetch + 1; ++i) { - // mbarrier::arrive(mbarrier0[stage + i]]); - // mbarrier::arrive(mbarrier1[stage + i]); - // ... - // } - // where mbarrierX[stage + i] is the X-th WAR mbarrier for stage i. - // - // This is needed because we prefetch data in circular buffering, and we - // need to make sure the initial prefetches are not blocked by the - // non-existing WAR hazards. - ForLoop* createArrivesForWar(ForLoop* circular_buffer_loop) { - const auto& opt = - GpuLower::current()->circularBufferInfo().getCircularBufferOptionsFor( - circular_buffer_loop->iter_domain()); - auto circular_buffer_tvs = - GpuLower::current()->circularBufferInfo().getCircularBufferTvs( - circular_buffer_loop->iter_domain()); - VectorOfUniqueEntries mbarriers; - for (auto tv : circular_buffer_tvs) { - auto ldst = dynamic_cast(tv->definition()); - NVF_ERROR(ldst != nullptr); - auto it = GpuLower::current()->ldstMBarrierMap().find(ldst); - if (it == GpuLower::current()->ldstMBarrierMap().end()) { - continue; - } - mbarriers.pushBack(it->second); - } - auto prefetch_loop = ir_utils::createRangeLoop(opt.prefetch + 1); - for (auto mbarrier : mbarriers) { - auto mbarrier_to_arrive = IrBuilder::create( - mbarrier, - SimplifyingIrBuilder::addExpr( - prefetch_loop->indexOrStartIfTrivial(), opt.stage)); - auto prefetch = IrBuilder::create( - /*state=*/nullptr, mbarrier_to_arrive); - prefetch_loop->body().push_back(prefetch); - } - return prefetch_loop; - } - - static bool usesMBarrierForWAR(ForLoop* circular_buffer_loop) { - return GpuLower::current() - ->circularBufferInfo() - .getCircularBufferOptionsFor(circular_buffer_loop->iter_domain()) - .usesMBarrierForWAR(); - } - - void insertTmaPipelined( - ForLoop* circular_buffer_loop, - const std::vector& loads) { - // Arrive on the WAR mbarriers to let the prefetching start. - if (usesMBarrierForWAR(circular_buffer_loop)) { - auto prefetch_loop = createArrivesForWar(circular_buffer_loop); - registerInsertBefore(circular_buffer_loop, prefetch_loop); - } - - // Prologue loop: - // - launch only - // - arrive_expect_tx and tma load operations - if (hasPrefetch(circular_buffer_loop)) { - // If there is no prefetch, then we don't need a prologue loop. - ForLoop* prologue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone( - circular_buffer_loop, - loads, - CircularBufferLoopStage::Prolog, - /*insertion_position=*/1); - registerInsertBefore(circular_buffer_loop, prologue_loop); - } - - // Main loop: - // - Launch and wait - // - arrive_expect_tx, tma load operations, and mbarrier_wait - ForLoop* main_loop = CloneTmaCircularBufferLoopAndInsertSync::clone( - circular_buffer_loop, - loads, - CircularBufferLoopStage::Main, - /*insertion_position=*/1); - registerReplace(circular_buffer_loop, main_loop); - - if (!hasPrefetch(circular_buffer_loop)) { - // If there is no prefetch, then we don't need a epilogue loop. - return; - } - - // We can use exclude argument in - // CloneTmaCircularBufferLoopAndInsertSync clone to avoid - // duplicating allocations if main loop is trivial. - std::unordered_set expressions_allocated_in_main_loop; - getAllocInTrivialLoop(main_loop, expressions_allocated_in_main_loop); - - // Epilogue loop: - // - wait only - // - mbarrier_wait - ForLoop* epilogue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone( - circular_buffer_loop, - loads, - CircularBufferLoopStage::Epilog, - /*insertion_position=*/1, - expressions_allocated_in_main_loop); - registerInsertAfter(circular_buffer_loop, epilogue_loop); + processed_loop_ = loop; + insertion_info_.erase(loop); } void insertTmaWarpSpecialized( @@ -1462,24 +1490,24 @@ class CircularBufferInserter : private kir::ExprMutator { .num_registers.value(); GpuLower::current()->decIncRegisterUsage() = std::make_pair(decrease_num_registers, increase_num_registers); - // Decrease registers in load warp group - kir::SetMaxNReg* dec_reg_load_warp = IrBuilder::create( + // Decrease registers in async warp group + kir::SetMaxNReg* dec_reg_async_warp = IrBuilder::create( IrBuilder::create(decrease_num_registers, DataType::Index), /*increase_registers=*/false); - warp_dispatch_ite->thenBody().push_back(dec_reg_load_warp); + warp_dispatch_ite->thenBody().push_back(dec_reg_async_warp); // Increase registers in compute warp group - kir::SetMaxNReg* inc_reg_load_warp = IrBuilder::create( + kir::SetMaxNReg* inc_reg_async_warp = IrBuilder::create( IrBuilder::create(increase_num_registers, DataType::Index), /*increase_registers*/ true); - warp_dispatch_ite->elseBody().push_back(inc_reg_load_warp); + warp_dispatch_ite->elseBody().push_back(inc_reg_async_warp); } // Load loop: ForLoop* load_loop = CloneTmaCircularBufferLoopAndInsertSync::clone( circular_buffer_loop, loads, - CircularBufferLoopStage::LoadWarp, + CircularBufferLoopStage::AsyncWarp, insertion_position); warp_dispatch_ite->thenBody().push_back(load_loop); @@ -1505,6 +1533,145 @@ class CircularBufferInserter : private kir::ExprMutator { registerReplace(circular_buffer_loop, warp_dispatch_ite); } + private: + InsertionInfo& insertion_info_; + ForLoop* processed_loop_ = nullptr; +}; + +// Apply pipeline circular buffering transformations +class PipelineCircularBufferInserter : private kir::ExprMutator { + public: + // When there exist multiple circular buffer loops, apply + // transformations to inner-most loops first. A single ExprMutator + // pass can only process one loop. + static std::vector run( + const std::vector& exprs, + InsertionInfo insertion_info) { + std::vector inserted_exprs = exprs; + while (!insertion_info.empty()) { + PipelineCircularBufferInserter inserter(inserted_exprs, insertion_info); + inserted_exprs = inserter.exprs_; + } + return inserted_exprs; + } + + private: + PipelineCircularBufferInserter( + const std::vector& exprs, + InsertionInfo& insertion_info) + : insertion_info_(insertion_info) { + size_t num_circular_buffer_loops = insertion_info.size(); + traverseAndInsert(exprs); + NVF_ERROR(processed_loop_ != nullptr); + NVF_ERROR(insertion_info.size() == num_circular_buffer_loops - 1); + } + + using kir::ExprMutator::handle; + + void handle(ForLoop* loop) final { + kir::ExprMutator::handle(loop); + + // If another loop is already taken care of, no more loop should + // be done in the same pass + if (processed_loop_ != nullptr) { + return; + } + + auto it = insertion_info_.find(loop); + if (it == insertion_info_.end()) { + return; + } + + bool use_warp_specialization = std::holds_alternative( + GpuLower::current() + ->circularBufferInfo() + .getCircularBufferOptionsFor(loop->iter_domain()) + .type); + NVF_ERROR(!use_warp_specialization); + + auto has_cp_async_bulk = std::any_of( + it->second.begin(), it->second.end(), ir_utils::isCpAsyncBulk); + if (has_cp_async_bulk) { + insertTmaPipelined(loop, it->second); + } else { + insert(loop, it->second); + } + + processed_loop_ = loop; + insertion_info_.erase(loop); + } + + bool hasPrefetch(ForLoop* circular_buffer_loop) { + int64_t prefetch_distance = + GpuLower::current() + ->circularBufferInfo() + .getCircularBufferOptionsFor(circular_buffer_loop->iter_domain()) + .prefetch; + return prefetch_distance > 0; + } + + static bool usesMBarrierForWAR(ForLoop* circular_buffer_loop) { + return GpuLower::current() + ->circularBufferInfo() + .getCircularBufferOptionsFor(circular_buffer_loop->iter_domain()) + .usesMBarrierForWAR(); + } + + void insertTmaPipelined( + ForLoop* circular_buffer_loop, + const std::vector& loads) { + // Arrive on the WAR mbarriers to let the prefetching start. + if (usesMBarrierForWAR(circular_buffer_loop)) { + auto prefetch_loop = createArrivesForWar(circular_buffer_loop); + registerInsertBefore(circular_buffer_loop, prefetch_loop); + } + + // Prologue loop: + // - launch only + // - arrive_expect_tx and tma load operations + if (hasPrefetch(circular_buffer_loop)) { + // If there is no prefetch, then we don't need a prologue loop. + ForLoop* prologue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone( + circular_buffer_loop, + loads, + CircularBufferLoopStage::Prolog, + /*insertion_position=*/1); + registerInsertBefore(circular_buffer_loop, prologue_loop); + } + + // Main loop: + // - Launch and wait + // - arrive_expect_tx, tma load operations, and mbarrier_wait + ForLoop* main_loop = CloneTmaCircularBufferLoopAndInsertSync::clone( + circular_buffer_loop, + loads, + CircularBufferLoopStage::Main, + /*insertion_position=*/1); + registerReplace(circular_buffer_loop, main_loop); + + if (!hasPrefetch(circular_buffer_loop)) { + // If there is no prefetch, then we don't need a epilogue loop. + return; + } + + // We can use exclude argument in + // CloneTmaCircularBufferLoopAndInsertSync clone to avoid + // duplicating allocations if main loop is trivial. + std::unordered_set expressions_allocated_in_main_loop; + getAllocInTrivialLoop(main_loop, expressions_allocated_in_main_loop); + + // Epilogue loop: + // - wait only + // - mbarrier_wait + ForLoop* epilogue_loop = CloneTmaCircularBufferLoopAndInsertSync::clone( + circular_buffer_loop, + loads, + CircularBufferLoopStage::Epilog, + /*insertion_position=*/1, + expressions_allocated_in_main_loop); + registerInsertAfter(circular_buffer_loop, epilogue_loop); + } + void insert(ForLoop* circular_buffer_loop, const std::vector& loads) { NVF_ERROR( !usesMBarrierForWAR(circular_buffer_loop), @@ -1728,8 +1895,15 @@ kir::TensorIndex* TmaCircularBufferInfo::getTensorIndex(const Expr* expr) { } std::vector CircularBufferPass::run(const std::vector& exprs) { - InsertionInfo insertion_info = CircularBufferLoopNestInspector::run(exprs); - return CircularBufferInserter::run(exprs, insertion_info); + auto&& [ws_insertion_info, pipeline_insertion_info] = + CircularBufferLoopNestInspector::run(exprs); + // Process circular buffer for-loops from inner to outer-most. + // Pipeline must come before WarpSpecialized. We cannot nest WarpSpecialized + // inside of Pipeline circular buffering. + std::vector result_exprs = + PipelineCircularBufferInserter::run(exprs, pipeline_insertion_info); + return WarpSpecializedCircularBufferInserter::run( + result_exprs, ws_insertion_info); } } // namespace nvfuser diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index 25605ddd0d4..ba0082e9dbc 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -1522,7 +1522,7 @@ void IndexLowering::handleCpAsyncBulkLoad(const LoadStoreOp* ldst) { GpuLower::current()->propagateExprInfo(ldst, back()); } else { - TensorView* mbarrier = GpuLower::current()->ldstMBarrierMap().at(ldst); + TensorView* mbarrier = GpuLower::current()->mbarrierMap().at(ldst); Val* mbarrier_index = lower_utils::u32IndexScalarSmemTv(mbarrier); // gmem indexing and expect_bytes for mbarrier @@ -2697,6 +2697,14 @@ void IndexLowering::handle(const MmaOp* mma) { auto mma_indexed = IrBuilder::create(out, a, b, mma->init(), mma->macro()); pushBack(mma_indexed); + if (mma->isBlackwell()) { + pushBack(IrBuilder::create( + "tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64", + std::vector{}, + std::vector{lower_utils::u32IndexScalarSmemTv( + GpuLower::current()->mbarrierMap().at(mma))}, + kir::Asm::Options{/*volatile=*/true})); + } GpuLower::current()->propagateExprInfo(mma, back()); } @@ -2805,6 +2813,11 @@ void IndexLowering::handle(const kir::SetMaxNReg* maxnreg) { pushBack(const_cast(maxnreg)); // NOLINT } +void IndexLowering::handle(const kir::Continue* cont) { + // TODO(kir): remove the need for const_cast + pushBack(const_cast(cont)); // NOLINT +} + void IndexLowering::handle(const kir::Return* ret) { // TODO(kir): remove the need for const_cast pushBack(const_cast(ret)); // NOLINT diff --git a/csrc/device_lower/pass/index.h b/csrc/device_lower/pass/index.h index 7e1699821c2..2bc0e973709 100644 --- a/csrc/device_lower/pass/index.h +++ b/csrc/device_lower/pass/index.h @@ -77,6 +77,7 @@ class IndexLowering : private OptOutConstDispatch { void handle(const kir::FenceAsyncProxy*) final; void handle(const kir::WgMmaFence*) final; void handle(const kir::SetMaxNReg*) final; + void handle(const kir::Continue*) final; void handle(const kir::Return*) final; void handle(const kir::MBarrierInit*) final; void handle(const kir::MBarrierInvalidate*) final; diff --git a/csrc/device_lower/pass/inline_ptx.cpp b/csrc/device_lower/pass/inline_ptx.cpp index 704475a0152..7ac4aebc577 100644 --- a/csrc/device_lower/pass/inline_ptx.cpp +++ b/csrc/device_lower/pass/inline_ptx.cpp @@ -378,7 +378,7 @@ class LowerToInlinePtx : public kir::ExprMutator { Val* enable_input_d = getUseInputAcc(mma); // Do MMA - registerInsertBefore( + registerReplace( mma, IrBuilder::create( "tcgen05.mma.cta_group::1.kind::f16", diff --git a/csrc/device_lower/pass/insert_syncs.cpp b/csrc/device_lower/pass/insert_syncs.cpp index 3a6b11c7ca1..09fa9df9666 100644 --- a/csrc/device_lower/pass/insert_syncs.cpp +++ b/csrc/device_lower/pass/insert_syncs.cpp @@ -20,10 +20,10 @@ namespace nvfuser { namespace { -// Determine if any for loop is a LoadWarp circular buffering stage -bool isWithinLoadWarp(const std::vector for_loops) { +// Determine if any for loop is a AsyncWarp circular buffering stage +bool isWithinAsyncWarp(const std::vector for_loops) { return std::any_of(for_loops.begin(), for_loops.end(), [](ForLoop* fl) { - return fl->circularBufferLoopStage() == CircularBufferLoopStage::LoadWarp; + return fl->circularBufferLoopStage() == CircularBufferLoopStage::AsyncWarp; }); } @@ -36,16 +36,16 @@ bool isWithinComputeWarp(const std::vector for_loops) { } // Return true if any for loop is ComputeWarp. -// Return false if any for loop is LoadWarp. +// Return false if any for loop is AsyncWarp. // Return std:nullopt if none of the for loops are a warp specialized stage. std::optional isOptionalComputeSync( const std::vector for_loops) { - bool contains_load_warp = isWithinLoadWarp(for_loops); + bool contains_async_warp = isWithinAsyncWarp(for_loops); bool contains_compute_warp = isWithinComputeWarp(for_loops); NVF_ERROR( - !contains_load_warp || !contains_compute_warp, - "The list of for-loops contains both LoadWarp and ComputeWarp stages."); - if (isWithinLoadWarp(for_loops)) { + !contains_async_warp || !contains_compute_warp, + "The list of for-loops contains both AsyncWarp and ComputeWarp stages."); + if (isWithinAsyncWarp(for_loops)) { return false; } else if (isWithinComputeWarp(for_loops)) { return true; @@ -54,6 +54,30 @@ std::optional isOptionalComputeSync( } } +// Commit a series of operations to an async group. +// Create wgmma.fence for AsyncOpType::WgMma +// Otherwise, create fence.proxy.async +Expr* getAsyncFence(AsyncOpType async_type) { + if (async_type == AsyncOpType::WgMma) { + return IrBuilder::create(); + } + return IrBuilder::create(); +} + +// Commit a series of operations to an async group. +// Create wgmma.commit_group.sync.aligned for AsyncOpType::WgMma +// Create cpAsyncBulkCommitGroup for AsyncOpType::CpAsyncBulk +Expr* getAsyncCommit(AsyncOpType async_type) { + return IrBuilder::create(async_type); +} + +// Wait for a number of async groups to finish. +// Create wgmma.wait_group.sync.aligned for AsyncOpType::WgMma +// Create cpAsyncBulkWaitGroup for AsyncOpType::CpAsyncBulk +Expr* getAsyncWait(AsyncOpType async_type, int64_t keep_stages = 0) { + return IrBuilder::create(async_type, keep_stages); +} + // Tensor memory is similar to shared memory because they are both // shared between threads in a block. In that sense, we can consider // tensor memory as special type of shared memory. In this file, we use @@ -449,8 +473,7 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { // except when these are accumulator register accesses across multiple // wgmma.mma_async instructions of the same shape. In the latter case, // an ordering guarantee is provided by default. - auto wgmma_fence = IrBuilder::create(); - registerInsertBefore(expr, wgmma_fence, scope); + registerInsertBefore(expr, getAsyncFence(AsyncOpType::WgMma), scope); if (!lower_utils::allMmaInputsGuardedByMBarrier(mma)) { // fence.proxy.async makes sure that writes to operands in the generic // proxy are visible to the async proxy @@ -463,19 +486,13 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { // async mma pipeline has not been flushed yet. flush_async_mma_pipeline_ = false; } else if (mma->isBlackwell()) { - // TODO: This is clearly a wrong way to sync, but as an intermediate - // step to enable incremental development, we use nanosleep to sync the - // mma. We should replace this with a correct sync method. - registerInsertBefore(expr, IrBuilder::create()); registerInsertAfter( expr, - IrBuilder::create( - "nanosleep.u32", - std::vector{}, - std::vector{ - IrBuilder::create(4000000000, DataType::UInt32)}, - kir::Asm::Options{/*volatile=*/true})); - registerInsertAfter(expr, IrBuilder::create()); + IrBuilder::create( + IrBuilder::create( + GpuLower::current()->mbarrierMap().at(expr), + expr->fusion()->zeroVal()), + expr->fusion()->zeroVal(DataType::UInt32))); } } else if (ir_utils::isCpAsyncBulkStore(expr)) { // Add a fence before TMA store so that writes in the generic proxy is @@ -522,10 +539,11 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { } } for (const auto& [async_type, ops] : input_async_ops) { - auto sync_exprs = lower_utils::getSyncExprs( - async_type, - /*keep_stages=*/0, - /*requires_commit=*/async_type != AsyncOpType::WgMma); + std::vector sync_exprs; + if (async_type != AsyncOpType::WgMma) { + sync_exprs.push_back(getAsyncCommit(async_type)); + } + sync_exprs.push_back(getAsyncWait(async_type, /*keep_stages=*/0)); for (auto sync_expr : sync_exprs) { insertSyncExpr(ops, expr, sync_expr, nullptr); } @@ -843,8 +861,9 @@ class ReadAfterWriteSyncs : public kir::ExprMutator { // kernel. for (auto expr : async_exprs_writing_fusion_output_) { auto async_type = ir_utils::getAsyncOpType(expr); - auto sync_exprs = - lower_utils::getSyncExprs(async_type, /*keep_stages=*/0); + std::vector sync_exprs{ + getAsyncCommit(async_type), + getAsyncWait(async_type, /*keep_stages=*/0)}; exprs_.insert(exprs_.end(), sync_exprs.begin(), sync_exprs.end()); } @@ -963,7 +982,7 @@ class WarAsyncWaitInserter : private kir::ExprMutator { //! Warp Specialization creates an If-Then-Else to separate load and compute //! operations. Therefore, the async_inputs_in_current_scope_ will not contain //! the async inputs for the corresponding async expression. Track async - //! inputs separately when we encounter them in load warp. + //! inputs separately when we encounter them in async warp. std::unordered_set warp_specialized_async_inputs_in_current_scope_; //! Track async exprs separately when we encounter them in compute warp. @@ -1027,7 +1046,7 @@ class WarAsyncWaitInserter : private kir::ExprMutator { return; } - // Gather all async inputs in LoadWarp + // Gather all async inputs in AsyncWarp TensorView* out_tv = ir_utils::getTvOutput(expr); NVF_ERROR(out_tv != nullptr); auto circular_buffer_loop = @@ -1035,7 +1054,7 @@ class WarAsyncWaitInserter : private kir::ExprMutator { out_tv, for_loops_); if (circular_buffer_loop != nullptr && circular_buffer_loop->circularBufferLoopStage() == - CircularBufferLoopStage::LoadWarp) { + CircularBufferLoopStage::AsyncWarp) { auto use_async_ops = getUseAsyncOpTypes(out_tv); if (!use_async_ops.empty()) { warp_specialized_async_inputs_in_current_scope_.emplace(out_tv); @@ -1175,10 +1194,10 @@ class WarAsyncWaitInserter : private kir::ExprMutator { // Special logic is required for warp specialized circular buffering because // the TMA loads and wgmma ops are separated by an IfThenElse. // kir::ExprMutator traverses the fusion in depth-wise order, so TMA loads in - // the LoadWarp are detected before the wgmma expressions in the ComputeWarp. + // the AsyncWarp are detected before the wgmma expressions in the ComputeWarp. // // This function inserts wgmma.commit_group and wgmma.wait_group expressions - // before the mbarrier::arrive, which allows load warp to launch next TMA + // before the mbarrier::arrive, which allows async warp to launch next TMA // load. First, we commit all the wgmma expressions issued in this iteration // of the for-loop. Then, we wait for some number of wgmma expressions based // on number of circular buffer stages and number of prefetch stages. @@ -1198,7 +1217,7 @@ class WarAsyncWaitInserter : private kir::ExprMutator { NVF_ERROR( warp_specialized_async_exprs_to_protect_.empty() || !warp_specialized_async_inputs_in_current_scope_.empty(), - "Expected TMA loads in LoadWarp for WgMma operations were detected in ComputeWarp."); + "Expected TMA loads in AsyncWarp for WgMma operations were detected in ComputeWarp."); // short-circuit: no wgmma expressions to protect in computeWarp. if (warp_specialized_async_exprs_to_protect_.empty()) { @@ -1206,7 +1225,7 @@ class WarAsyncWaitInserter : private kir::ExprMutator { return; } - // Establish all tma loads in LoadWarp are used by WgMma operations in + // Establish all tma loads in AsyncWarp are used by WgMma operations in // ComputeWarp. for (Expr* expr : warp_specialized_async_exprs_to_protect_) { if (ir_utils::isCpAsyncBulkStore(expr)) { @@ -1225,8 +1244,9 @@ class WarAsyncWaitInserter : private kir::ExprMutator { active_compute_for_loop_->iter_domain()); int64_t pending_ops = opt.stage - opt.prefetch - 1; - auto sync_exprs = - lower_utils::getSyncExprs(AsyncOpType::WgMma, pending_ops); + std::vector sync_exprs{ + getAsyncCommit(AsyncOpType::WgMma), + getAsyncWait(AsyncOpType::WgMma, /*keep_stages=*/pending_ops)}; size_t num_exprs = for_loop->body().exprs().size(); NVF_ERROR(num_exprs > 1); NVF_ERROR(for_loop->body().exprs().back()->isA()); @@ -1308,7 +1328,9 @@ class WarAsyncWaitInserter : private kir::ExprMutator { // Actually insert these wait expressions. for (auto [type, pending_ops] : types_and_pending_ops_to_protect) { - auto sync_exprs = lower_utils::getSyncExprs(type, pending_ops); + std::vector sync_exprs{ + getAsyncCommit(type), + getAsyncWait(type, /*keep_stages=*/pending_ops)}; NVF_ERROR(!for_loop->body().exprs().empty()); // Default position is last expression in for loop diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index 936cfc11513..ac4981ccafa 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -139,6 +139,7 @@ bool isTvOp(const Expr* expr) { TensorConstruct, SelectOp, IndexSelectOp, + IndexPutAccumulateOp, GatherOp, ScatterOp, RNGOp, @@ -2099,21 +2100,6 @@ bool allMmaInputsGuardedByMBarrier(const MmaOp* mma) { ir_utils::isCpAsyncBulkLoad(ir_utils::getTv(mma->inB())->definition()); } -std::vector getSyncExprs( - AsyncOpType async_type, - int64_t keep_stages, - bool requires_commit) { - std::vector sync_exprs; - sync_exprs.reserve(2); - if (requires_commit) { - auto commit = IrBuilder::create(async_type); - sync_exprs.push_back(commit); - } - auto wait = IrBuilder::create(async_type, keep_stages); - sync_exprs.push_back(wait); - return sync_exprs; -} - } // namespace lower_utils } // namespace nvfuser diff --git a/csrc/device_lower/utils.h b/csrc/device_lower/utils.h index b45c7e2e3f6..7abc0ab6bfc 100644 --- a/csrc/device_lower/utils.h +++ b/csrc/device_lower/utils.h @@ -378,16 +378,6 @@ struct IterDomainDependencySorter { // Check if all the inputs of the given MmaOp is guarded by mbarrier bool allMmaInputsGuardedByMBarrier(const MmaOp* mma); -// Create a list of expressions that will be used to wait for async operations. -// For example, if op_type is AsyncOpType::WgMma, then the returned expressions -// will be: -// wgmma.commit_group.sync.aligned -// wgmma.wait_group.sync.aligned -std::vector getSyncExprs( - AsyncOpType async_type, - int64_t keep_stages = 0, - bool requires_commit = true); - } // namespace lower_utils } // namespace nvfuser diff --git a/csrc/dispatch.h b/csrc/dispatch.h index 218ccd8267a..007287e49f7 100644 --- a/csrc/dispatch.h +++ b/csrc/dispatch.h @@ -84,6 +84,7 @@ class Val; f(TensorConstruct); \ f(SelectOp); \ f(IndexSelectOp); \ + f(IndexPutAccumulateOp); \ f(GatherOp); \ f(ScatterOp); \ f(RNGOp); \ @@ -124,6 +125,7 @@ class Val; f(FenceAsyncProxy); \ f(WgMmaFence); \ f(SetMaxNReg); \ + f(Continue); \ f(Return); \ f(MBarrierInit); \ f(MBarrierInvalidate); \ @@ -158,7 +160,8 @@ class Val; f(Synchronize); \ f(StartCoalescing); \ f(EndCoalescing); \ - f(ShareMemHandles); + f(ShareMemHandles); \ + f(HirAliasSelect); // Forward declarations for all Val and Expr types diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index cdb241b3517..7619cab1ce6 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -276,43 +277,6 @@ std::vector SegmentedGroup:: return merge_candidates; } -void SegmentedGroup::clearTraversalInfo() { - level_ = -1; - merge_with_ = nullptr; - merge_through_ = nullptr; - merged_ = false; -} - -std::vector SegmentedGroup::edgesToVals( - const std::vector& se_v) { - std::vector ret_v; - ret_v.reserve(se_v.size()); - - std::transform( - se_v.cbegin(), - se_v.cend(), - std::back_inserter(ret_v), - [](SegmentedEdge* se) { return se->val; }); - return ret_v; -} - -template -void insertUniquePredicated( - std::vector& v, - const std::vector& e, - PREDICATE pred) { - VectorOfUniqueEntries to_add; - for (auto edge : e) { - to_add.pushBack(edge->val); - } - - std::copy_if( - to_add.vector().begin(), - to_add.vector().end(), - std::back_inserter(v), - [pred](Val* val) { return pred(val); }); -} - // TODO: Reevaluate what's being done in finalize void SegmentedGroup::finalize() { // Make sure all inputs and outputs of the group are now in input and output @@ -604,7 +568,7 @@ void SegmentedFusion::deserialize(const serde::SegmentedFusion* buffer) { // Construct segmented groups first because they are necessary for the // segmented edge's constructor - // NOTE: Use regular for-loop to avoid unused variable ‘idx’ error + // NOTE: Use regular for-loop to avoid unused variable 'idx' error for (size_t idx = 0; idx < buffer->groups()->size(); ++idx) { newGroup(); } @@ -683,25 +647,57 @@ SegmentedEdge* SegmentedFusion::Impl::makeEdge( return edges_.back().get(); } +void SegmentedFusion::removeEdge(SegmentedEdge* edge) { + NVF_ERROR(edge != nullptr, "Edge is nullptr"); + // Validate edge exists in all expected locations + SegmentedGroup* producer = edge->from; + SegmentedGroup* consumer = edge->to; + auto& producer_consumer_edges = producer->consumer_edges; + auto& consumer_producer_edges = consumer->producer_edges; + + // Remove edge from producer's consumer edges + auto producer_edge_it = std::find( + producer_consumer_edges.begin(), producer_consumer_edges.end(), edge); + NVF_ERROR( + producer_edge_it != producer_consumer_edges.end(), + "Edge not found in producer's consumer edges"); + producer_consumer_edges.erase(producer_edge_it); + + // Remove edge from consumer's producer edges + auto consumer_edge_it = std::find( + consumer_producer_edges.begin(), consumer_producer_edges.end(), edge); + NVF_ERROR( + consumer_edge_it != consumer_producer_edges.end(), + "Edge not found in consumer's producer edges"); + consumer_producer_edges.erase(consumer_edge_it); + + // Remove edge from global edge list + auto edge_it = std::find(edges_.begin(), edges_.end(), edge); + NVF_ERROR(edge_it != edges_.end(), "Edge not found in global edge list"); + edges_.erase(edge_it); +} + void SegmentedFusion::Impl::cleanUnused() { std::unordered_set g_used( owning_fusion_->groups().begin(), owning_fusion_->groups().end()); std::unordered_set e_used( owning_fusion_->edges().begin(), owning_fusion_->edges().end()); - groups_.erase( - std::remove_if( - groups_.begin(), - groups_.end(), - [&g_used](auto& g) { return g_used.count(g.get()) == 0; }), - groups_.end()); - + // Remove any edges that are no longer in use edges_.erase( std::remove_if( edges_.begin(), edges_.end(), [&e_used](auto& e) { return e_used.count(e.get()) == 0; }), edges_.end()); + + // Remove any groups that are no longer in use + groups_.erase( + std::remove_if( + groups_.begin(), + groups_.end(), + [&g_used](auto& g) { return g_used.count(g.get()) == 0; }), + groups_.end()); } //! Return mapping from SegmentedGroup to integer id @@ -2016,55 +2012,28 @@ void SegmentCandidateFinder::resetLevels() { } // Disconect group from neighbors, and return edges that were disconnected -std::unordered_set SegmentCandidateFinder::disconnectGroup( - SegmentedGroup* group) { - std::unordered_set removed_edges( +void SegmentCandidateFinder::disconnectGroup(SegmentedGroup* group) { + // Remove producer edges + std::vector producer_edges( group->producer_edges.begin(), group->producer_edges.end()); - - for (auto edge : group->producer_edges) { - auto from = edge->from; - auto& from_edges = from->consumer_edges; - auto from_edge_it = std::find(from_edges.begin(), from_edges.end(), edge); - NVF_ERROR( - from_edge_it != from_edges.end(), "Could not find edge to remove."); - from_edges.erase(from_edge_it); + for (auto edge : producer_edges) { + segmented_fusion_->removeEdge(edge); } - for (auto edge : group->consumer_edges) { - removed_edges.insert(edge); - auto to = edge->to; - auto& to_edges = to->producer_edges; - auto to_edge_it = std::find(to_edges.begin(), to_edges.end(), edge); - NVF_ERROR(to_edge_it != to_edges.end(), "Could not find edge to remove."); - to_edges.erase(to_edge_it); + // Remove consumer edges + std::vector consumer_edges( + group->consumer_edges.begin(), group->consumer_edges.end()); + for (auto edge : consumer_edges) { + segmented_fusion_->removeEdge(edge); } - - group->producer_edges.clear(); - group->consumer_edges.clear(); - - return removed_edges; } void SegmentCandidateFinder::eraseGroups( std::unordered_set& groups_to_erase) { - std::unordered_set edges_to_erase; for (auto group : groups_to_erase) { - auto disconnected_edges = disconnectGroup(group); - edges_to_erase.insert(disconnected_edges.begin(), disconnected_edges.end()); + disconnectGroup(group); } - edges().erase( - std::remove_if( - edges().begin(), - edges().end(), - [&edges_to_erase](SegmentedEdge* edge) { - if (edges_to_erase.find(edge) != edges_to_erase.end()) { - return true; - }; - return false; - }), - edges().end()); - groups().erase( std::remove_if( groups().begin(), @@ -2078,6 +2047,30 @@ void SegmentCandidateFinder::eraseGroups( groups().end()); } +std::vector SegmentedFusion::getEdgesBetween( + const SegmentedGroup* producer, + const SegmentedGroup* consumer) const { + std::vector edges_between; + + // Look through producer's consumer edges + for (SegmentedEdge* edge : producer->consumer_edges) { + if (edge->to == consumer) { + edges_between.push_back(edge); + } + } + + return edges_between; +} + +void SegmentedFusion::connectGroups( + SegmentedGroup* producer, + SegmentedGroup* consumer, + Val* val) { + SegmentedEdge* new_edge = newEdge(producer, consumer, val); + producer->consumer_edges.push_back(new_edge); + consumer->producer_edges.push_back(new_edge); +} + SegmentedGroup* SegmentCandidateFinder::mergeNodes() { FUSER_PERF_SCOPE("SegmentCandidateFinder::mergeNodes"); SegmentedGroup* last_merged = nullptr; @@ -2093,90 +2086,65 @@ SegmentedGroup* SegmentCandidateFinder::mergeNodes() { // Make the new joined node auto joined_group = segmented_fusion_->newGroup(); + // Merge input and output vals joined_group->input_vals_ = group1->input_vals_.computeUnion(group2->input_vals_); - joined_group->output_vals_ = group1->output_vals_.computeUnion(group2->output_vals_); + // Merge expressions joined_group->exprs_ = group1->exprs_; joined_group->exprs_.insert( joined_group->exprs_.end(), group2->exprs_.begin(), group2->exprs_.end()); + // Get all edges that will connect to the new joined group auto producer_edges = getMergedProducerEdges(group1, group2); - // Connect joined group to resulting neighbors - for (auto edge : producer_edges) { - auto from = edge->from; - auto val = edge->val; + auto consumer_edges = getMergedConsumerEdges(group1, group2); - auto new_edge = segmented_fusion_->newEdge(from, joined_group, val); - joined_group->producer_edges.push_back(new_edge); - from->consumer_edges.push_back(new_edge); + // Connect all producer edges to the new joined group + for (auto edge : producer_edges) { + segmented_fusion_->connectGroups(edge->from, joined_group, edge->val); } - auto consumer_edges = getMergedConsumerEdges(group1, group2); - + // Connect all consumer edges from the new joined group for (auto edge : consumer_edges) { - auto to = edge->to; - auto val = edge->val; - - auto new_edge = segmented_fusion_->newEdge(joined_group, to, val); - joined_group->consumer_edges.push_back(new_edge); - edge->to->producer_edges.push_back(new_edge); + segmented_fusion_->connectGroups(joined_group, edge->to, edge->val); } - // Disconnect the merged groups before deriveSchedulerType, which - // may temporarily inject type cast and can get confused if stale - // edges exist + // Now that all new connections are made, disconnect the old groups, this + // invalidates producer_edges and consumer_edges for (auto merged_group : {group1, group2}) { - auto disconnected_edges = disconnectGroup(merged_group); - clean_up_edges_.insert( - disconnected_edges.begin(), disconnected_edges.end()); + disconnectGroup(merged_group); } + // Set scheduler type for the new group joined_group->setSchedulerType(deriveSchedulerType(joined_group)); - // Need to maintain the group dependency data if it has been intialized - // by previous merging + + // Update group dependency data if initialized if (group_dependency_) { group_dependency_->as()->mergeGroups( group1, group2, joined_group); } + last_merged = joined_group; } to_merge_.clear(); - edges().erase( - std::remove_if( - edges().begin(), - edges().end(), - [this](SegmentedEdge* edge) { - if (this->clean_up_edges_.find(edge) != - this->clean_up_edges_.end()) { - return true; - }; - return false; - }), - edges().end()); - + // Clean up merged groups groups().erase( std::remove_if( groups().begin(), groups().end(), [this](SegmentedGroup* group) { - if (this->clean_up_groups_.find(group) != - this->clean_up_groups_.end()) { - return true; - }; - return false; + return this->clean_up_groups_.find(group) != + this->clean_up_groups_.end(); }), groups().end()); - clean_up_edges_.clear(); clean_up_groups_.clear(); - return last_merged; } @@ -2187,77 +2155,70 @@ SegmentedGroup* SegmentCandidateFinder::mergeAllGivenGroups( const std::vector& groups_to_merge) { NVF_ERROR( !groups_to_merge.empty(), - "fusion segment :(mergeAllGivenGroups) tried to merge no groups") + "fusion segment :(mergeAllGivenGroups) tried to merge no groups"); + + // The fusion input auxiliary groups should never be merged. + const auto& aux_input_groups = getAuxiliaryInputGroups(); + std::vector aux_groups_to_merge; + std::ranges::copy_if( + groups_to_merge, + std::back_inserter(aux_groups_to_merge), + [&](SegmentedGroup* group) { + return std::ranges::find(aux_input_groups, group) != + aux_input_groups.end(); + }); + NVF_ERROR( + aux_groups_to_merge.empty(), + "Trying to merge auxiliary input groups: ", + toDelimitedString(aux_groups_to_merge)); // Make a set to detect internal edges std::unordered_set group_set( groups_to_merge.begin(), groups_to_merge.end()); - // Sets to de-duplicate multiple uses of - // edge values and re-computations of exprs - std::unordered_set used_edge_vals_set; - std::unordered_set exprs_set; - // Create new group auto joined_group = segmented_fusion_->newGroup(); - // Populate edges, exprs, global vals - // from each of the groups + // Track unique vals and exprs to avoid duplicates + std::unordered_set used_edge_vals_set; + std::unordered_set exprs_set; + + // Merge inputs and outputs from all groups for (auto group : groups_to_merge) { - // Populate complete fusion inputs to the group joined_group->input_vals_.pushBack(group->input_vals_); joined_group->output_vals_.pushBack(group->output_vals_); + } - // Populate producer edges to the group - for (auto edge : group->producer_edges) { - if ( - // Check this is not internal edge - !group_set.count(edge->from) && - // Check this val has been added or not - !used_edge_vals_set.count(edge->val)) { - used_edge_vals_set.insert(edge->val); - auto new_producer_edge = - segmented_fusion_->newEdge(edge->from, joined_group, edge->val); - joined_group->producer_edges.push_back(new_producer_edge); - edge->from->consumer_edges.push_back(new_producer_edge); - } - } + // Get all edges that will connect to the new joined group + auto all_edges = getAllEdges(groups_to_merge); - // Populate consumer edges from the group - for (auto edge : group->consumer_edges) { - if ( - // Check this is not internal edge - !group_set.count(edge->to)) { - auto new_consumer_edge = - segmented_fusion_->newEdge(joined_group, edge->to, edge->val); - joined_group->consumer_edges.push_back(new_consumer_edge); - edge->to->producer_edges.push_back(new_consumer_edge); - } + // Connect all external edges to the new joined group + for (auto edge : all_edges) { + if (group_set.count(edge->from)) { + // This is a consumer edge from the merged group + segmented_fusion_->connectGroups(joined_group, edge->to, edge->val); + } else { + // This is a producer edge to the merged group + segmented_fusion_->connectGroups(edge->from, joined_group, edge->val); } + } + + // Disconnect all original groups before connecting the new one, this + // invalidates all_edges + for (auto group : groups_to_merge) { + disconnectGroup(group); + } - // Populate exprs + // Merge all expressions from the groups + for (auto group : groups_to_merge) { for (auto expr : group->exprs_) { - if (!exprs_set.count(expr)) { + if (exprs_set.insert(expr).second) { joined_group->exprs_.push_back(expr); - exprs_set.insert(expr); } } } - // Clean up original groups from segmented fusion - for (auto group : groups_to_merge) { - auto disconnected_edges = disconnectGroup(group); - clean_up_edges_.insert( - disconnected_edges.begin(), disconnected_edges.end()); - } - - edges().erase( - std::remove_if( - edges().begin(), - edges().end(), - [this](SegmentedEdge* edge) { return clean_up_edges_.count(edge); }), - edges().end()); - + // Clean up original groups groups().erase( std::remove_if( groups().begin(), @@ -2267,8 +2228,6 @@ SegmentedGroup* SegmentCandidateFinder::mergeAllGivenGroups( }), groups().end()); - clean_up_edges_.clear(); - joined_group->setSchedulerType(deriveSchedulerType(joined_group)); return joined_group; } @@ -2324,8 +2283,9 @@ class FusionSegmentGuard : public NonCopyable { num_original_exprs_ = fusion_->exprs().size(); original_tvs_ = fusion_->allTvs(); #endif // NDEBUG - lowered_edges_ = segmented_fusion_->castInputOutputToLowerPrecision( - segmented_fusion_->edges()); + lowered_precision_edges_ = + segmented_fusion_->castInputOutputToLowerPrecision( + segmented_fusion_->edges()); } // Insert cast and narrow the fusion to a merged group of a and b @@ -2349,7 +2309,7 @@ class FusionSegmentGuard : public NonCopyable { consumer_edges.begin(), consumer_edges.end(), std::back_inserter(all_edges)); - lowered_edges_ = + lowered_precision_edges_ = segmented_fusion_->castInputOutputToLowerPrecision(all_edges, {a, b}); auto new_inputs = getAllInputs(a, b); @@ -2373,8 +2333,9 @@ class FusionSegmentGuard : public NonCopyable { // Cast inputs and outputs of a merged group consisting of // segmented_groups. auto all_edges = getAllEdges(segmented_groups); - lowered_edges_ = segmented_fusion_->castInputOutputToLowerPrecision( - all_edges, segmented_groups); + lowered_precision_edges_ = + segmented_fusion_->castInputOutputToLowerPrecision( + all_edges, segmented_groups); auto new_inputs = allInputsIfTrueElseOutputs(segmented_groups, true); auto new_outputs = allInputsIfTrueElseOutputs(segmented_groups, false); @@ -2393,8 +2354,9 @@ class FusionSegmentGuard : public NonCopyable { restoreOriginalSegment(); // Revert the cast - if (segmented_fusion_ != nullptr && !lowered_edges_.empty()) { - segmented_fusion_->revertInputOutputPrecisionChanges(lowered_edges_); + if (segmented_fusion_ != nullptr && !lowered_precision_edges_.empty()) { + segmented_fusion_->revertInputOutputPrecisionChanges( + lowered_precision_edges_); } #ifndef NDEBUG @@ -2473,7 +2435,7 @@ class FusionSegmentGuard : public NonCopyable { Fusion* const fusion_ = nullptr; std::vector old_inputs_; std::vector old_outputs_; - std::vector lowered_edges_; + std::vector lowered_precision_edges_; #ifndef NDEBUG size_t num_original_exprs_ = 0; std::vector original_tvs_; @@ -2585,7 +2547,9 @@ std::vector SegmentedGroup::stablyOrderedExprs() const { std::unordered_map num_producers; for (Expr* e : exprs()) { int64_t& n = num_producers[e]; - for (Val* in : e->inputs()) { + // Val::uses(), which is used later to decrement num_producers, contains + // unique `Expr`s. Therefore, it's necessary to also dedup here. + for (auto* in : VectorOfUniqueEntries(e->inputs())) { Expr* def = in->definition(); // Exprs in a SegmentedGroup come from the complete fusion, so the // producer/consumer of an Expr may be outside the group. Therefore, we @@ -3352,7 +3316,6 @@ class CombineReductions { return groups_to_merge_set.has(group); }), groups_with_reductions_.end()); - return joined_group; } } @@ -3657,6 +3620,10 @@ class MergeUpAndDownCast { SegmentedGroup* group = to_visit.front(); to_visit.pop_front(); + if (group->exprs().empty()) { + continue; + } + if (groups_to_merge_set.count(group)) { continue; } @@ -3749,27 +3716,6 @@ class MergeUpAndDownCast { namespace { -//! Returns true if group1 and group2 are an immediate producer-consumer pair. -bool areDirectlyConnected(SegmentedGroup* group1, SegmentedGroup* group2) { - // Check if group1 is a immediate consumer of group2 - if (std::any_of( - group1->producer_edges.begin(), - group1->producer_edges.end(), - [group2](SegmentedEdge* edge) { return edge->from == group2; })) { - return true; - } - - // Check if group1 is a immediate producer of group2 - if (std::any_of( - group1->consumer_edges.begin(), - group1->consumer_edges.end(), - [group2](SegmentedEdge* edge) { return edge->to == group2; })) { - return true; - } - - return false; -} - //! Allow the segmentation algorithm to prefer certain exprs to merge class PreferredMergeCandidatePicker { public: @@ -3879,7 +3825,8 @@ bool SegmentCandidateFinder::codeGenSupportedMerge( SegmentedGroup* group2) { FUSER_PERF_SCOPE("SegmentCandidateFinder::codeGenSupportedMerge"); NVF_ERROR( - areDirectlyConnected(group1, group2), + !segmented_fusion_->getEdgesBetween(group1, group2).empty() || + !segmented_fusion_->getEdgesBetween(group2, group1).empty(), "only support testing immediate producer-consumer groups"); // The segmemter should ideally be redesigned to be more flexible and // decoupled from the schedulers, but for now, we just return @@ -3979,9 +3926,7 @@ void SegmentCandidateFinder::buildInitialSegments() { if (isFusionInput(inp)) { expr_group->input_vals_.pushBack(inp); auto aux_group = input2group_.at(inp); - auto new_edge = segmented_fusion_->newEdge(aux_group, expr_group, inp); - expr_group->producer_edges.push_back(new_edge); - aux_group->consumer_edges.push_back(new_edge); + segmented_fusion_->connectGroups(aux_group, expr_group, inp); continue; } @@ -3999,9 +3944,7 @@ void SegmentCandidateFinder::buildInitialSegments() { } auto def_group = expr2group.at(inp->definition()); - auto new_edge = segmented_fusion_->newEdge(def_group, expr_group, inp); - expr_group->producer_edges.push_back(new_edge); - def_group->consumer_edges.push_back(new_edge); + segmented_fusion_->connectGroups(def_group, expr_group, inp); } for (auto out : expr->outputs()) { if (out->isFusionOutput()) { @@ -4375,6 +4318,8 @@ void SegmentCandidateFinder::revertPrivatizedUpcast(SegmentedGroup* group) { maybe_deduplicate_edge(consumer_edge_to_update); } + std::erase(group->exprs_, uop); + // Note that it should not be necessary to do anything with // group->output_vals since the inserted upcast ops should never produce // fusion outputs. @@ -4432,12 +4377,28 @@ void SegmentCandidateFinder::forwardInputs() { excluded_inp_unary_exprs_ = {}; input2group_.clear(); + std::vector extended_fusion_inputs = completeFusion()->inputs(); + + // Grab factory ops that should be forwarded. Add created tensors to + // the fusion input list to make them handled like fusion inputs + // TODO: Handle more factory methods such as IotaOp, EyeOp, + // TensorConstruct. Probably should not include relatively expensive + // ops like RNGOp. + for (auto expr : completeFusion()->exprs()) { + if (expr->isA() && + // Don't bother if it's a fusion output + !expr->output(0)->isFusionOutput()) { + extended_fusion_inputs.push_back(expr->output(0)); + excluded_inp_unary_exprs_.pushBack(expr); + } + } + // "Terminating" outputs from the excluded input unary exprs, these will be // treated as complete fusion inputs. VectorOfUniqueEntries forwarded_inputs; { std::deque to_visit; - for (Val* inp : completeFusion()->inputs()) { + for (Val* inp : extended_fusion_inputs) { if (UnaryOp* unary_use = shouldForward(inp)) { to_visit.push_back(unary_use); } @@ -4460,11 +4421,13 @@ void SegmentCandidateFinder::forwardInputs() { } } - auto excluded_fusion_inputs = IterVisitor::getInputsTo( - {forwarded_inputs.begin(), forwarded_inputs.end()}); + // Stop traversing back at factory vals (and fusion inputs) + auto excluded_fusion_inputs = InputsOf::getInputsTo( + {forwarded_inputs.begin(), forwarded_inputs.end()}, + extended_fusion_inputs); // List of vals to treat as complete fusion inputs for segmentation - forwarded_fusion_inputs_ = completeFusion()->inputs(); + forwarded_fusion_inputs_ = extended_fusion_inputs; forwarded_fusion_inputs_.erase( std::remove_if( @@ -4503,6 +4466,16 @@ void SegmentCandidateFinder::cleanupForwardedInputs() { input2group_.clear(); } +std::vector SegmentCandidateFinder::getAuxiliaryInputGroups() + const { + std::vector aux_groups; + aux_groups.reserve(input2group_.size()); + std::ranges::transform(input2group_, aux_groups.begin(), [](const auto& kv) { + return kv.second; + }); + return aux_groups; +} + void SegmentCandidateFinder::finalMerge() { FUSER_PERF_SCOPE("SegmentCandidateFinder::finalMerge"); auto producer_check = getGroupDependency(); @@ -4715,7 +4688,14 @@ void SegmentCandidateFinder::resolveScalarsInGroup(SegmentedGroup* group) { SegmentedGroup* SegmentCandidateFinder::createInputGroup(Val* forwarded_input) { SegmentedGroup* group = segmented_fusion_->newGroup(); - group->input_vals_ = IterVisitor::getInputsTo({forwarded_input}); + for (auto inp : IterVisitor::getInputsTo({forwarded_input})) { + // inp may be a factory-created tensor, which is not an input to + // the group. + if (std::ranges::find(completeFusion()->inputs(), inp) != + completeFusion()->inputs().end()) { + group->input_vals_.pushBack(inp); + } + } group->exprs_ = StmtSort::getExprsTo({forwarded_input}); return group; } @@ -4725,26 +4705,29 @@ void SegmentCandidateFinder::resolveNonscalarForwardedInput( SegmentedGroup* aux_group = input2group_.at(forwarded_input); NVF_ERROR(aux_group->producer_edges.empty()); - // use unordered_set to avoid duplicated group in consumers. - // duplicated entry in consumer would make use call - // codeGenSupportedMerge(input_group, consumer) twice. Where the second time - // the connection has already been severed by mergeNodes(). GroupSet consumers; for (SegmentedEdge* edge : aux_group->consumer_edges) { consumers.pushBack(edge->to); } - aux_group->consumer_edges.clear(); for (SegmentedGroup* consumer : consumers) { SegmentedGroup* input_group = createInputGroup(forwarded_input); - - for (SegmentedEdge*& edge : consumer->producer_edges) { + std::vector edges_to_remove; + std::vector producer_edge_copy = consumer->producer_edges; + // Use a copy to iterate over edges as connect group can invalidate the + // original iterator + for (SegmentedEdge* edge : producer_edge_copy) { if (edge->from == aux_group && edge->val == forwarded_input) { - edge->from = input_group; - input_group->consumer_edges.push_back(edge); + // Create new edges before removing old ones + segmented_fusion_->connectGroups( + input_group, consumer, forwarded_input); + // Now safe to remove old edges + edges_to_remove.push_back(edge); } } - + for (auto edge_to_remove : edges_to_remove) { + segmented_fusion_->removeEdge(edge_to_remove); + } consumer->input_vals_.erase(forwarded_input); if (codeGenSupportedMerge(input_group, consumer)) { @@ -4763,21 +4746,18 @@ void SegmentCandidateFinder::removeScalarEdges() { // translation. // we will not need them after scalar // resolution - auto remove_scalar_edges_from_vec = [](std::vector& edges) { - edges.erase( - std::remove_if( - edges.begin(), - edges.end(), - [](SegmentedEdge* segmented_edge) { - return segmented_edge->val->isScalar(); - }), - edges.end()); - }; - remove_scalar_edges_from_vec(edges()); - for (auto group : groups()) { - remove_scalar_edges_from_vec(group->producer_edges); - remove_scalar_edges_from_vec(group->consumer_edges); + // Collect all scalar edges first since removeEdge modifies the edge lists + std::vector scalar_edges; + for (auto edge : edges()) { + if (edge->val->isScalar()) { + scalar_edges.push_back(edge); + } + } + + // Remove each scalar edge through the proper API + for (auto edge : scalar_edges) { + segmented_fusion_->removeEdge(edge); } } diff --git a/csrc/fusion_segmenter.h b/csrc/fusion_segmenter.h index 6d6306ded2d..cfcbf75fd5d 100644 --- a/csrc/fusion_segmenter.h +++ b/csrc/fusion_segmenter.h @@ -151,8 +151,10 @@ class SegmentedGroup { std::optional> getMaybeHeuristicParams( SchedulerRuntimeInfo& runtime_info); - //! Query if this is a group for a fusion input - bool isFusionInputGroup() const; + //! Get the SegmentedFusion this group belongs to + const SegmentedFusion* segmentedFusion() const { + return segmented_fusion_; + } public: //! "Ancestor nodes", towards inputs of segmentedDAG @@ -192,9 +194,6 @@ class SegmentedGroup { //! Theorem 4.2 int level_ = -1; - //! traversal marker, has this node already been processed - bool visited_ = false; - //! Did we select another group to merge with SegmentedGroup* merge_with_ = nullptr; @@ -205,13 +204,6 @@ class SegmentedGroup { bool merged_ = false; private: - //! Utility to convert edge vector to value vector - std::vector edgesToVals(const std::vector& se_v); - - //! Reset method to call at begining of each - //! merge node iteration - void clearTraversalInfo(); - //! To be called at the very end of segment fusion //! no more segment merging should be done beyond void finalize(); @@ -333,6 +325,13 @@ class SegmentedFusion { //! API for adding edges SegmentedEdge* newEdge(SegmentedGroup* from, SegmentedGroup* to, Val* val); + //! Remove an edge from the segmented fusion graph and update all affected + //! groups The edge object will be deleted and should not be used after this + //! call + void removeEdge(SegmentedEdge* edge); + + void connectGroups(SegmentedGroup* from, SegmentedGroup* to, Val* val); + HeuristicDataCache* getCachedHeuristicDataFor(SegmentedGroup* group); //! Lower FP precision of inputs and outputs specified by the given @@ -363,6 +362,11 @@ class SegmentedFusion { //! Grab edges with val std::vector getEdgesByVal(Val* val) const; + //! Get edges between two groups + std::vector getEdgesBetween( + const SegmentedGroup* from, + const SegmentedGroup* to) const; + //! Serialize SegmentedFusion using flatbuffers flatbuffers::Offset serialize( flatbuffers::FlatBufferBuilder& builder) const; @@ -400,7 +404,6 @@ class SegmentedFusion { SegmentedGroup* makeGroup(); SegmentedGroup* makeGroup(Expr*); - SegmentedGroup* makeFusionInputGroup(); SegmentedEdge* makeEdge(SegmentedGroup* from, SegmentedGroup* to, Val* val); void cleanUnused(); std::unordered_map groups_map() const; @@ -580,7 +583,7 @@ class SegmentCandidateFinder { SegmentedGroup* group, std::vector candidates = {}); - std::unordered_set disconnectGroup(SegmentedGroup* group); + void disconnectGroup(SegmentedGroup* group); std::vector& groups() { NVF_ERROR( @@ -690,6 +693,9 @@ class SegmentCandidateFinder { val) != forwarded_fusion_inputs_.end(); }; + // Get all auxiliary groups created for fusion inputs + std::vector getAuxiliaryInputGroups() const; + protected: //! These are the merge node heuristic passes, should //! eventually should have a dedicated interface @@ -701,7 +707,6 @@ class SegmentCandidateFinder { SegmentCandidateFinderOptions options_; std::unordered_set clean_up_groups_; - std::unordered_set clean_up_edges_; std::vector to_merge_; diff --git a/csrc/host_ir/container.cpp b/csrc/host_ir/container.cpp index 83e668770fc..9fdcfa376a6 100644 --- a/csrc/host_ir/container.cpp +++ b/csrc/host_ir/container.cpp @@ -35,11 +35,13 @@ Stream* HostIrContainer::getDefaultStream() { std::ostream& HostIrContainer::print(std::ostream& os) const { IrMathPrinter op_exprs(os); op_exprs.handle(this); - os << "Aliases:{"; - for (const auto& alias : alias_) { - os << "\n " << alias.first << " -> " << alias.second; + if (alias_.size() > 0) { + os << "Aliases:{"; + for (const auto& alias : alias_) { + os << "\n " << alias.first << " -> " << alias.second; + } + os << "\n}\n"; } - os << "\n}\n"; return os; } diff --git a/csrc/host_ir/executor.cpp b/csrc/host_ir/executor.cpp index 3a3c0921d2a..2f2cf9e7b92 100644 --- a/csrc/host_ir/executor.cpp +++ b/csrc/host_ir/executor.cpp @@ -13,6 +13,7 @@ #include #include #include +#include #include #include #include @@ -640,17 +641,44 @@ void HostIrEvaluator::handle(LoadStoreOp* load_store_op) { auto* out_tv = load_store_op->out()->as(); auto in_tensor = getKnownConcreteValue(load_store_op->in()).as(); - // If output has root domain, it means that the set op is a permute, which we - // don't support currently - NVF_ERROR( - !out_tv->hasRoot(), "the set op", load_store_op, "must not be a permute"); - - if (!isKnown(load_store_op->out())) { - bind(load_store_op->out(), in_tensor); + at::Tensor t; + if (out_tv->hasRoot()) { + std::optional> permutation = + ir_utils::computePermutation( + out_tv->getRootDomain(), out_tv->getLogicalDomain()); + NVF_ERROR( + permutation.has_value(), + "The logical domain of a Set.Permute is supposed to be a permutation" + " of the root domain: ", + out_tv); + t = in_tensor.permute(*permutation); } else { + t = in_tensor; + } + + if (isKnown(out_tv)) { auto out_tensor = getKnownConcreteValue(load_store_op->out()).as(); - out_tensor.copy_(in_tensor); + out_tensor.copy_(t); + } else { + // For completeness, we may check if out_tv's allocation matches `t` and + // copy data if yes. For example, + // + // clang-format off + // ``` + // const auto& [sizes, strides] = inferShapeOfOutput(out_tv, expr_evaluator_); + // if (strides == t.strides()) { + // bind(out_tv, t); + // } else { + // auto out_tensor = at::empty_strided(sizes, strides, in_tensor.dtype()); + // out_tensor.copy_(t); + // bind_(out_tv, out_tensor); + // } + // ``` + // clang-format on + // + // For now, I choose to keep code simple for the limited use cases. + bind(out_tv, t); } } @@ -677,6 +705,15 @@ void HostIrEvaluator::handle(kir::Allocate* allocate) { bind(tv, tensor); } +void HostIrEvaluator::handle(HirAliasSelect* hir_alias_select) { + auto index = + expr_evaluator_.evaluate(hir_alias_select->index()).as(); + auto input = getKnownConcreteValue(hir_alias_select->in()->as()) + .as(); + int64_t axis = hir_alias_select->axis(); + bind(hir_alias_select->out(), input.select(axis, index)); +} + void HostIrEvaluator::handle(BinaryOp* binary_op) { if (!isKnown(binary_op->outputs().at(0))) { return unhandled(binary_op); diff --git a/csrc/host_ir/executor.h b/csrc/host_ir/executor.h index dfe84fba068..3f147b7801b 100644 --- a/csrc/host_ir/executor.h +++ b/csrc/host_ir/executor.h @@ -97,6 +97,10 @@ class HostIrEvaluator final : public OptOutDispatch { return container_->outputs(); } + auto* container() const { + return container_.get(); + } + std::ostream& print(std::ostream& os) const { return container_->print(os); }; @@ -138,6 +142,7 @@ class HostIrEvaluator final : public OptOutDispatch { void handle(BinaryOp* binary_op) override; void handle(ReductionOp* reduction_op) override; void handle(ShareMemHandles* share_mem_handles) override; + void handle(HirAliasSelect* hir_alias_select) override; void unhandled(Statement* stmt) override; c10::cuda::CUDAStream getCUDAStream(Stream* stream); diff --git a/csrc/host_ir/host_ir.cpp b/csrc/host_ir/host_ir.cpp index 9e1386d0d3d..bf3d5cef9eb 100644 --- a/csrc/host_ir/host_ir.cpp +++ b/csrc/host_ir/host_ir.cpp @@ -355,6 +355,51 @@ std::string ShareMemHandles::toInlineString(int indent_size) const { NVF_THROW("Cannot be printed inline"); } +HirAliasSelect::HirAliasSelect( + IrBuilderPasskey passkey, + TensorView* in, + TensorView* out, + int64_t axis, + Val* index) + : Expr(passkey, {in, index}, {}, {}) { + NVF_ERROR(passkey.ir_container_ != nullptr); + NVF_ERROR( + passkey.ir_container_->isA(), + this, + "must be registered in a HostIrContainer"); + NVF_ERROR( + static_cast(in->getLogicalDomain().size()) > axis, + "Select axis ", + axis, + " is out of bounds for tensor ", + in->toString(), + " with ", + in->getLogicalDomain().size(), + " dimensions"); + // "out" is not added as an output because the current op doesn't "define" it, + // but rather sets its allocation. Since "out" will be used in another + // producing expression, this avoids unnecessary cyclic dependencies. This + // ressembles how kir::Allocate treats its allocated TensorView. + addAttribute(out); + addDataAttribute(axis); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(HirAliasSelect) + +std::string HirAliasSelect::toString(int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << out()->toString() << "\n"; + indent_size++; + indent(ss, indent_size) << " = HirAliasSelect( " << in()->toString() + << ", axis = " << in()->getLogicalDomain().at(axis()) + << ", index = " << index()->toString() << " )\n"; + return ss.str(); +} + +std::string HirAliasSelect::toInlineString(int indent_size) const { + NVF_THROW("Cannot be printed inline"); +} + } // namespace hir } // namespace nvfuser diff --git a/csrc/host_ir/host_ir.h b/csrc/host_ir/host_ir.h index bad3a6ef722..d267d23ab1f 100644 --- a/csrc/host_ir/host_ir.h +++ b/csrc/host_ir/host_ir.h @@ -351,6 +351,49 @@ class ShareMemHandles : public Expr { } }; +// This op mimicks the semantics of SelectOp but is used in HIR non-SSA context +// to index into a TensorView, returning an alias "slice" of the original +// TensorView. +class HirAliasSelect : public Expr { + public: + using Expr::Expr; + HirAliasSelect( + IrBuilderPasskey passkey, + TensorView* in, + TensorView* out, + int64_t axis, + Val* index); + + HirAliasSelect(const HirAliasSelect& other) = delete; + HirAliasSelect& operator=(const HirAliasSelect& other) = delete; + HirAliasSelect(HirAliasSelect&& other) = delete; + HirAliasSelect& operator=(HirAliasSelect&& other) = delete; + + NVFUSER_DECLARE_CLONE_AND_CREATE + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + const char* getOpString() const override { + return "hir::HirAliasSelect"; + } + + TensorView* in() const { + return inputs().at(0)->as(); + } + + TensorView* out() const { + return attributeVal(0)->as(); + } + + int64_t axis() const { + return attribute(1); + } + + Val* index() const { + return inputs().at(1); + } +}; + } // namespace hir } // namespace nvfuser diff --git a/csrc/host_ir/lower.cpp b/csrc/host_ir/lower.cpp index fd14096b190..ca9bb80ae4e 100644 --- a/csrc/host_ir/lower.cpp +++ b/csrc/host_ir/lower.cpp @@ -7,6 +7,7 @@ // clang-format on #include #include +#include #include #include #include @@ -735,6 +736,13 @@ std::unique_ptr HostIrLower::lower( hic->addOutput(ir_cloner.clone(output)); } + for (auto tv : hic->allTvs()) { + // set all host tensors to global memory type. This must be the case by + // definition of a host tensor, and setting the memory type to global is + // also required to avoid Allocate HIR nodes to throw + tv->setMemoryType(MemoryType::Global); + } + std::vector new_top_level_exprs; for (auto top_level_expr : hic->topLevelExprs()) { if (!isResharding(top_level_expr)) { @@ -761,6 +769,8 @@ std::unique_ptr HostIrLower::lower( } hic->resetTopLevelExprs(new_top_level_exprs); + preseg_passes::OptimizationPass::runPass(hic.get()); + return hic; } diff --git a/csrc/host_ir/pass/stream_parallel_type.cpp b/csrc/host_ir/pass/stream_parallel_type.cpp new file mode 100644 index 00000000000..d7bfa0f090a --- /dev/null +++ b/csrc/host_ir/pass/stream_parallel_type.cpp @@ -0,0 +1,440 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nvfuser::hir { + +namespace { + +// Finds the stream axis in a tensor's domain. There should be at most one +// stream axis. +IterDomain* getStreamAxis(const std::vector& domain) { + IterDomain* ret = nullptr; + for (auto id : domain) { + if (id->getParallelType() == ParallelType::Stream) { + NVF_CHECK( + ret == nullptr, + "Expected at most one stream axis in the domain, but found ", + id, + " and ", + ret); + ret = id; + } + } + return ret; +} + +// Validates that a stream axis is valid in a tensor +void validateStreamAxis(IterDomain* stream_axis, const TensorView* tv) { + // Find the stream axis in the logical domain + auto it_logical_stream_axis = std::find( + tv->getLogicalDomain().begin(), + tv->getLogicalDomain().end(), + stream_axis); + + // Verify stream axis is not split/merged + NVF_ERROR( + it_logical_stream_axis != tv->getLogicalDomain().end(), + "Cannot stream parallelize on a split/merge axis ", + stream_axis); + + // Verify stream axis is an iteration or broadcast axis + NVF_CHECK( + stream_axis->getIterType() == IterType::Iteration || + stream_axis->getIterType() == IterType::Broadcast, + "Stream axis ", + stream_axis, + " should be an iteration or broadcast axis."); +} + +// Checks if two iteration domains are mapped in the ID model +bool areIdsMapped(const IdModel& id_model, IterDomain* id1, IterDomain* id2) { + return id_model.idGraph(IdMappingMode::BROADCAST) + .disjointValSets() + .strictAreMapped(id1, id2); +} + +// Determines if a stream-parallel for-loop can be merged with the previous one +bool canMergeWithPreviousForLoop( + const std::vector& new_top_level_exprs, + IterDomain* stream_axis, + const IdModel& id_model) { + return !new_top_level_exprs.empty() && + new_top_level_exprs.back()->isA() && + areIdsMapped( + id_model, + stream_axis, + new_top_level_exprs.back()->as()->iterDomain()); +} + +// Finds where a stream axis appears in a tensor's logical domain +int64_t findStreamAxisIndex( + const TensorView* tv, + IterDomain* stream_axis, + const IdModel& id_model) { + int64_t stream_id_logical_index = -1; + for (auto id : tv->getLoopDomain()) { + if (areIdsMapped(id_model, stream_axis, id)) { + // Verify only one stream axis exists + NVF_CHECK( + stream_id_logical_index == -1, + "Expected at most one axis mapping to the stream axis ", + stream_axis, + " in the tensor ", + tv, + " loop's domain ", + tv->getLoopDomain()); + + // Find stream axis in logical domain + auto it_stream_id_logical = std::find( + tv->getLogicalDomain().begin(), tv->getLogicalDomain().end(), id); + NVF_CHECK( + it_stream_id_logical != tv->getLogicalDomain().end(), + "Expected to find ", + id, + " in ", + tv, + "'s logical domain ", + tv->getLogicalDomain()); + stream_id_logical_index = + std::distance(tv->getLogicalDomain().begin(), it_stream_id_logical); + } + } + return stream_id_logical_index; +} + +// Cache for tensor slicing operations in stream parallelization. +// This cache stores previously created sliced versions of tensors to avoid +// redundant slicing operations. A sliced tensor is created by removing a +// specific axis (stream axis) from the tensor's domain and creating a new +// tensor that represents a slice of the original tensor at a given index. +// The cache key is a tuple of (original tensor, axis index to remove, slice +// index). +struct TensorSlicingCache { + // Type aliases + using Key = std::tuple; + + // Custom hash function for the tuple used as cache key + struct Hash { + size_t operator()(const Key& t) const { + auto [tv, idx, val] = t; + return std::hash{}(tv) ^ std::hash{}(idx) ^ + std::hash{}(val); + } + }; + + // Map type for storing cached sliced tensors + using Map = std::unordered_map; + + // Get the expr producing the indexed version of a tensor. If the expr already + // exists in the cache, returns the cached version. Otherwise, creates a new + // expr, producing a tensor "selected" on its dimension `stream_axis_index` at + // index `index`. Returns a pair of (expr, is_new) where is_new indicates + // whether the expr was newly created. + std::pair get( + TensorView* tensor, + int64_t stream_axis_index, + Val* index) { + auto key = std::make_tuple(tensor, stream_axis_index, index); + auto it = cache_.find(key); + if (it != cache_.end()) { + return {it->second, false}; + } + + auto dom = tensor->getLogicalDomain(); + std::vector new_root; + new_root.reserve(dom.size() - 1); + + for (auto i : arange((int64_t)dom.size())) { + if (i != stream_axis_index) { + new_root.emplace_back(dom[i]->cloneWithoutRFactor()); + } + } + + auto td = IrBuilder::create( + new_root, TensorDomain::getContiguityFilledWith(new_root, true)); + auto out = IrBuilder::create(td, *tensor->getDataType()); + auto result = IrBuilder::create( + tensor, out, stream_axis_index, index); + + cache_[key] = result; + return {result, true}; + } + + private: + Map cache_; // Storage for cached sliced tensors +}; + +// Step 1: Group expressions into stream-parallel regions +std::vector groupStreamParallelRegions( + const std::vector& top_level_exprs, + const IdModel& id_model) { + std::vector new_top_level_exprs; + + for (auto* expr : top_level_exprs) { + // Skip expressions with no outputs + if (expr->outputs().size() == 0) { + new_top_level_exprs.push_back(expr); + continue; + } + + // Each expression should have exactly one output + NVF_CHECK( + expr->outputs().size() == 1, + "Each expr should have at most one output."); + + // Get the output tensor and check for stream parallelization + TensorView* output = expr->output(0)->as(); + IterDomain* stream_axis = getStreamAxis(output->getLoopDomain()); + + // If no stream axis found, keep the expression as is + if (stream_axis == nullptr) { + new_top_level_exprs.push_back(expr); + continue; + } + + // Verify that the expression can be handled as a standalone host operation + NVF_ERROR( + HostIrLower::isLowerableAsStandaloneHostOp(expr), + "Stream parallel type not supported for expr ", + expr); + + // Validate stream axis + validateStreamAxis(stream_axis, output); + + // Check if we can merge this expression with the previous for-loop + if (canMergeWithPreviousForLoop( + new_top_level_exprs, stream_axis, id_model)) { + // Merge with existing for-loop by adding the expression to its body + new_top_level_exprs.back()->as()->body().push_back(expr); + } else { + // Create a new for-loop for stream parallelization + auto* for_loop = IrBuilder::create( + stream_axis, + /*index=*/NamedScalar::getParallelIndex(ParallelType::Stream), + /*start=*/FusionGuard::getCurFusion()->zeroVal(), + /*stop=*/stream_axis->extent(), + /*step=*/FusionGuard::getCurFusion()->oneVal(), + /*vectorize=*/false, + /*vectorize_shift=*/nullptr, + /*unroll_required=*/false, + CircularBufferLoopStage::NotApplicable, + /*circular_buffer_loop_stage_depth=*/0); + // Add the expression to the new for-loop's body + for_loop->body().push_back(expr); + new_top_level_exprs.push_back(for_loop); + } + } + + return new_top_level_exprs; +} + +// Helper function to add allocations for tensors that need them +std::vector addTensorAllocations( + std::vector top_level_exprs, + const IdModel& id_model) { + std::vector new_top_level_exprs; + + for (auto* expr : top_level_exprs) { + if (expr->isA()) { + // add allocations for tensors produced in the loop that have a stream + // axes + auto* for_loop = expr->as(); + for (auto* body_expr : for_loop->body().exprs()) { + for (auto* output : + ir_utils::filterByType(body_expr->outputs())) { + if (findStreamAxisIndex(output, for_loop->iterDomain(), id_model) != + -1) { + new_top_level_exprs.push_back( + IrBuilder::create(output, MemoryType::Global)); + } + } + } + } + new_top_level_exprs.push_back(expr); + } + + return new_top_level_exprs; +} + +// Step 3: Process for-loop bodies by slicing tensors +std::vector processForLoopBodies( + std::vector top_level_exprs, + const IdModel& id_model) { + TensorSlicingCache tensor_slicing_cache; + + for (auto* expr : top_level_exprs) { + if (!expr->isA()) { + continue; + } + + auto* for_loop = expr->as(); + std::vector new_loop_body; + + // Lambda to process a tensor in a for-loop body + auto processTensor = [&](Expr*& expr, TensorView* tensor) { + if (auto stream_idx = + findStreamAxisIndex(tensor, for_loop->iterDomain(), id_model); + stream_idx != -1) { + auto [slicing, is_new] = + tensor_slicing_cache.get(tensor, stream_idx, for_loop->index()); + if (is_new) { + new_loop_body.push_back(slicing); + } + expr = ir_utils::replaceValInExprInputs(expr, tensor, slicing->out()); + if (expr->outputs().size() > 0 && expr->outputs()[0] == tensor) { + expr = + ir_utils::transferDefinitionToNewOutputs(expr, {slicing->out()}); + } + } + }; + + for (auto* body_expr : for_loop->body().exprs()) { + for (auto* input : + ir_utils::filterByType(body_expr->inputs())) { + processTensor(body_expr, input); + } + for (auto* output : + ir_utils::filterByType(body_expr->outputs())) { + processTensor(body_expr, output); + } + new_loop_body.push_back(body_expr); + } + + for_loop->body().clear(); + for (auto* expr : new_loop_body) { + for_loop->body().push_back(expr); + } + } + + return top_level_exprs; +} + +// Step 4: Add stream management and synchronization +std::vector addStreamManagement(std::vector top_level_exprs) { + // Process each top-level expression + for (auto* top_level_expr : top_level_exprs) { + // Skip non-for-loop expressions + if (!top_level_expr->isA()) { + continue; + } + + auto* for_loop = top_level_expr->as(); + std::vector new_loop_body; + + // Get the current stream before entering the loop + auto* get_current_stream = IrBuilder::create(); + hir::Stream* original_stream = get_current_stream->stream(); + new_loop_body.push_back(get_current_stream); + + // Set up a new stream for this iteration based on the loop index + auto* number_of_streams = + IrBuilder::create("numberOfStreams", DataType::Int); + auto* stream_index = mod(for_loop->index(), number_of_streams); + auto* stream = IrBuilder::create(stream_index); + auto* set_stream = IrBuilder::create(stream); + new_loop_body.push_back(set_stream); + + // Synchronize with the original stream before starting computation + auto* initial_sync_stream = + IrBuilder::create(original_stream); + new_loop_body.push_back(initial_sync_stream); + + // Add all the expressions to the loop body + for (auto* expr : for_loop->body().exprs()) { + new_loop_body.push_back(expr); + } + + // Restore the original stream and synchronize with the iteration's stream + auto* set_back_original_stream = + IrBuilder::create(original_stream); + new_loop_body.push_back(set_back_original_stream); + auto* sync_stream = IrBuilder::create(stream); + new_loop_body.push_back(sync_stream); + + // Update the for-loop body with the new expressions + for_loop->body().clear(); + for (auto* expr : new_loop_body) { + for_loop->body().push_back(expr); + } + } + + return top_level_exprs; +} + +} // anonymous namespace + +// StreamParallelType pass implementation. +// This pass handles stream parallelization of operations in a fusion. +// It works by: +// 1. Identifying stream-parallelized axes in tensor operations +// 2. Grouping compatible operations into stream-parallel for-loops +// 3. Setting up proper stream synchronization and management +// 4. Adding allocations for tensors that need them +// The pass ensures that: +// - Input tensors don't have stream axes +// - Only one stream axis exists per tensor +// - Stream axes are properly synchronized +// - Operations are correctly grouped into stream-parallel regions +// - The resulting HostIrContainer's top level expression is valid for execution +// and does not contain any stream axes +// +// TODO: Here, we assume that the fusion input is a HostIrContainer and use the +// linear structure of the HostIrContainer::topLevelExpr to greedily merge the +// adjacent compatible stream for-loop bodies. Ideally we should look at the dag +// and use the segmenter. +void StreamParallelType::runPass(Fusion* fusion) { + // Verify that input tensors don't have stream axes + NVF_CHECK( + std::all_of( + fusion->inputs().begin(), + fusion->inputs().end(), + [](Val* input) { + auto input_tv = dynamic_cast(input); + return input_tv == nullptr || + getStreamAxis(input_tv->getLoopDomain()) == nullptr; + }), + "Expected no stream axis in the TensorView inputs."); + + // Set up the fusion environment and build the ID model + FusionGuard fg(fusion); + hir::HostIrContainer* hic = dynamic_cast(fusion); + NVF_CHECK(hic, "Expected HostIrContainer"); + + IdModel id_model(fusion); + id_model.buildBroadcastGraph(); + + // Step 1: Group expressions into stream-parallel regions + std::vector top_level_exprs = + groupStreamParallelRegions(hic->topLevelExprs(), id_model); + + // Step 2: Add allocations for tensors that need them + top_level_exprs = addTensorAllocations(std::move(top_level_exprs), id_model); + + // Step 3: Process for-loop bodies by slicing tensors + top_level_exprs = processForLoopBodies(std::move(top_level_exprs), id_model); + + // Step 4: Add stream management and synchronization + top_level_exprs = addStreamManagement(std::move(top_level_exprs)); + + // Update the container's top-level expressions + hic->resetTopLevelExprs(top_level_exprs); +} + +} // namespace nvfuser::hir diff --git a/csrc/host_ir/pass/stream_parallel_type.h b/csrc/host_ir/pass/stream_parallel_type.h new file mode 100644 index 00000000000..8b5f138ad7e --- /dev/null +++ b/csrc/host_ir/pass/stream_parallel_type.h @@ -0,0 +1,36 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include + +namespace nvfuser::hir { + +// A pass used in HostIrLower that takes a HostIrContainer as input, reads the +// TensorView's ParallelType::Stream, and modify the the HostIrContainer's top +// level expressions with the corresponding Host For Loops, which bodies contain +// stream assignement, selecting on tensor's axis, and the exprs on those sliced +// tensors. After this pass, the ParallelType::Stream is removed from the +// TensorView's axis. +// +// An illustration of the pass can be found in the tests +// `test_host_ir_stream_lowering.cpp` +// with the option `NVFUSER_DUMP=host_ir`. +class StreamParallelType + : public preseg_passes::OptimizationPass { + friend class preseg_passes::OptimizationPass; + + protected: + static void runPass(Fusion* fusion); + static constexpr std::string_view name() { + return "StreamParallelType"; + } +}; + +} // namespace nvfuser::hir diff --git a/csrc/ir/interface_nodes.h b/csrc/ir/interface_nodes.h index cffd277eeb2..0445339012d 100644 --- a/csrc/ir/interface_nodes.h +++ b/csrc/ir/interface_nodes.h @@ -192,7 +192,7 @@ class TVDomainGuard; // if (threadIdx.y == blockDim.y - 1) { // // If we use warp specialization on TIDy, then the blockDim.y of the // // kernel will be (whatever_value_inferred_from_schedule + 1), and the -// // last threadIdx.y will be used as load warp +// // last threadIdx.y will be used as async warp // for i in range(data.size): // wait buffer[i % stage] to be empty // load data[i] to buffer[i % stage] @@ -256,7 +256,7 @@ struct WarpSpecialized { validate_num_registers(num_registers.value().second); NVF_ERROR( num_registers.value().first <= num_registers.value().second, - "The number of registers for load warp group must be <= to the number", + "The number of registers for async warp group must be <= to the number", " of registers for the compute warp groups."); } diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h index 91d3ca4ec39..066823e42c9 100644 --- a/csrc/ir/internal_nodes.h +++ b/csrc/ir/internal_nodes.h @@ -129,6 +129,62 @@ class IndexSelectOp : public Expr { } }; +class IndexPutAccumulateOp : public Expr { + public: + using Expr::Expr; + + // [ Note -- IndexPutAccumulateOp semantics ] + // + // logical ID groups of IndexPutAccumulateOp + // args: + // acc [ ID_indexed_g0, ID_g0 ] + // index [ ID_indexing_g1, ID_broadcast ] + // value [ ID_indexing_g1, ID_g0 ] + // output: + // out [ ID_indexed_g0, ID_g0 ] + // + // Note that: + // 1. indexed ID for `out` and `acc` share the same extent. + // 2. indexed ID for `index` and `value` share the same extent. + IndexPutAccumulateOp( + IrBuilderPasskey, + Val* out, + Val* acc, + Val* index, + Val* value); + + NVFUSER_DECLARE_CLONE_AND_CREATE + + const char* getOpString() const override { + return "IndexPutAccumulateOp"; + } + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; + std::vector evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const override; + + TensorView* accumulateTv() const { + return input(0)->as(); + } + + TensorView* indexTv() const { + return input(1)->as(); + } + + TensorView* valueTv() const { + return input(2)->as(); + } + + // return ID_indexing_g1 from value + IterDomain* getIndexingIDOfValue() const; + + // return ID_indexing_g1 from index, for IndexPutAccumulate, there's only one + // indexing ID, while the remaining ID is broadcast + IterDomain* getIndexingID() const; +}; + class NVF_API GatherOp : public Expr { public: using Expr::Expr; @@ -2477,6 +2533,10 @@ class ForLoop final : public Expr { return input(0); } + IterDomain* iterDomain() const { + return input(1)->as(); + } + Val* indexOrStartIfTrivial() const { return isTrivial() ? start() : index(); } diff --git a/csrc/ir/iostream.cpp b/csrc/ir/iostream.cpp index 2b7e0c8f57a..59e01acbeb5 100644 --- a/csrc/ir/iostream.cpp +++ b/csrc/ir/iostream.cpp @@ -152,15 +152,10 @@ void IrTransformPrinter::printTransforms(const TensorView* tv) { os() << " contiguity: " << tv->domain()->getContiguityString() << "\n"; - const auto& from = tv->getLogicalDomain(); - const auto& loop = tv->getLoopDomain(); - const auto all_exp = DependencyCheck::getAllExprsBetween( - {from.begin(), from.end()}, {loop.begin(), loop.end()}); - - for (const auto exp : all_exp) { + for (const auto exp : tv->domain()->allExprs()) { os() << " " << exp->toString(); } - os() << " loop domain : (" << toDelimitedString(loop) << ")\n"; + os() << " loop domain : (" << toDelimitedString(tv->getLoopDomain()) << ")\n"; } std::ostream& operator<<(std::ostream& os, const Statement* stmt) { diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index de40861cb6c..577168d14cf 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -177,6 +177,53 @@ std::vector IndexSelectOp::evaluate( NVFUSER_DEFINE_CLONE_AND_CREATE(IndexSelectOp) +IndexPutAccumulateOp::IndexPutAccumulateOp( + IrBuilderPasskey passkey, + Val* out, + Val* acc, + Val* index, + Val* value) + : Expr(passkey) { + addInput(acc); + addInput(index); + addInput(value); + addOutput(out); +} + +std::string IndexPutAccumulateOp::toString(int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << output(0)->toString() << "\n"; + indent_size++; + indent(ss, indent_size) << " = indexPutAccumulate( "; + ss << input(0)->toString() << ", " << input(1)->toString() << ", " + << input(2)->toString() << " )\n"; + return ss.str(); +} + +std::string IndexPutAccumulateOp::toInlineString(int indent_size) const { + NVF_CHECK(false, "Tensor op can not be printed inline"); +} + +IterDomain* IndexPutAccumulateOp::getIndexingIDOfValue() const { + return TensorDomain::noReductions(valueTv()->getLogicalDomain()).front(); +} + +IterDomain* IndexPutAccumulateOp::getIndexingID() const { + return TensorDomain::noReductions(indexTv()->getLogicalDomain()).front(); +} + +std::vector IndexPutAccumulateOp::evaluate( + const ExpressionEvaluator& ee, + const std::vector& inputs) const { + return {at::index_put( + /*self=*/inputs.at(0).as(), + /*indices=*/{inputs.at(1).as()}, + /*values=*/inputs.at(2).as(), + /*accumulate=*/true)}; +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(IndexPutAccumulateOp) + GatherOp::GatherOp( IrBuilderPasskey passkey, Val* out, diff --git a/csrc/kernel_ir.cpp b/csrc/kernel_ir.cpp index b2d30c93b02..efc60b809e6 100644 --- a/csrc/kernel_ir.cpp +++ b/csrc/kernel_ir.cpp @@ -432,6 +432,8 @@ std::string Asm::utility() const { "tcgen05::relinquishAllocPermit"}, {"tcgen05.dealloc.cta_group::1.sync.aligned.b32", "tcgen05::dealloc"}, {"tcgen05.mma.cta_group::1.kind::f16", "tcgen05::mma_f16"}, + {"tcgen05.commit.cta_group::1.mbarrier::arrive::one.shared::cluster.b64", + "tcgen05::commit"}, {"wgmma.fence.sync.aligned", "wgmma::fence"}, {"fence.proxy.async", "fenceAsyncProxy"}, {"wgmma.commit_group.sync.aligned", "wgmma::commit"}, @@ -457,7 +459,7 @@ std::string Asm::utility() const { std::regex ld_pattern(R"(tcgen05\.ld\.sync\.aligned\.([^.]+)\.x\d+\.b32)"); std::smatch match; if (std::regex_match(code, match, ld_pattern)) { - std::string result = "tmem::load"; + std::string result = "tcgen05::load"; result.append(match[1]); return result; } @@ -466,7 +468,7 @@ std::string Asm::utility() const { std::regex st_pattern(R"(tcgen05\.st\.sync\.aligned\.([^.]+)\.x\d+\.b32)"); std::smatch match; if (std::regex_match(code, match, st_pattern)) { - std::string result = "tmem::store"; + std::string result = "tcgen05::store"; result.append(match[1]); return result; } @@ -670,6 +672,25 @@ std::string SetMaxNReg::toInlineString(int indent_size) const { NVFUSER_DEFINE_CLONE_AND_CREATE(SetMaxNReg) +Continue::Continue(IrBuilderPasskey passkey) : Expr(passkey) { + NVF_ERROR(passkey.ir_container_ != nullptr); + NVF_ERROR( + passkey.ir_container_->isA(), + "IR type only valid for Kernel container."); +} + +std::string Continue::toString(int indent_size) const { + std::stringstream ss; + indent(ss, indent_size) << "continue\n"; + return ss.str(); +} + +std::string Continue::toInlineString(int indent_size) const { + NVF_CHECK(false, "Continue can not be printed inline"); +} + +NVFUSER_DEFINE_CLONE_AND_CREATE(Continue) + Return::Return(IrBuilderPasskey passkey) : Expr(passkey) { NVF_ERROR(passkey.ir_container_ != nullptr); NVF_ERROR( diff --git a/csrc/kernel_ir.h b/csrc/kernel_ir.h index 5b08664774d..a255b389cbb 100644 --- a/csrc/kernel_ir.h +++ b/csrc/kernel_ir.h @@ -42,6 +42,7 @@ class GridSync; class FenceAsyncProxy; class WgMmaFence; class SetMaxNReg; +class Continue; class Return; class MBarrierInit; class MBarrierInvalidate; @@ -515,7 +516,7 @@ class BlockSync final : public Expr { return attribute>(1).value_or(false); } - bool isLoadWarpSync() const { + bool isAsyncWarpSync() const { auto optional_compute_or_load_sync = attribute>(1); return optional_compute_or_load_sync.has_value() && !optional_compute_or_load_sync.value(); @@ -613,6 +614,22 @@ class SetMaxNReg final : public Expr { } }; +class Continue final : public Expr { + public: + using Expr::Expr; + + explicit Continue(IrBuilderPasskey passkey); + + NVFUSER_DECLARE_CLONE_AND_CREATE + + const char* getOpString() const override { + return "Continue"; + } + + std::string toString(int indent_size = 0) const override; + std::string toInlineString(int indent_size = 0) const override; +}; + class Return final : public Expr { public: using Expr::Expr; diff --git a/csrc/logical_domain_map.cpp b/csrc/logical_domain_map.cpp index 9092fe44ffa..8e1cc09b527 100644 --- a/csrc/logical_domain_map.cpp +++ b/csrc/logical_domain_map.cpp @@ -74,32 +74,47 @@ PairwiseLogicalDomainMap::PairwiseLogicalDomainMap( namespace { -// Returns a producer ID that is indirectly accessed. A bool is also -// returned indicating there's a corresponding consumer ID. For -// example, select doesn't have a consumer ID, whereas index_select -// does. -std::pair getIndexedDomainInfo( +// Returns producer IDs that don't map identically to consumer. A bool is +// returned indicating whether corresponding consumer IDs exists. For example, +// select doesn't have a consumer ID, whereas index_select does. +std::pair, bool> getNonMappingDomainInfo( const TensorView* producer_tv, const TensorView* consumer_tv) { - IterDomain* indexed_id = nullptr; + std::unordered_set non_mapping_ids; bool has_consumer_id = false; if (auto sop = dynamic_cast(consumer_tv->definition())) { - indexed_id = sop->getIndexedID(); + // indexed ID is indirectly accessed + non_mapping_ids.insert(sop->getIndexedID()); has_consumer_id = false; } else if ( auto sop = dynamic_cast(consumer_tv->definition())) { + // indexed ID is indirectly accessed if (producer_tv == sop->lookupTv()) { - indexed_id = sop->getIndexedID(); + non_mapping_ids.insert(sop->getIndexedID()); has_consumer_id = true; } } else if (auto gop = dynamic_cast(consumer_tv->definition())) { + // indexed ID is indirectly accessed if (producer_tv == gop->lookupTv()) { - indexed_id = gop->getIndexedID(); + non_mapping_ids.insert(gop->getIndexedID()); + has_consumer_id = true; + } + } else if ( + auto iaop = + dynamic_cast(consumer_tv->definition())) { + // see [ Note -- IndexPutAccumulateOp semantics ] + if (producer_tv == iaop->indexTv()) { + // Indexing ID of index tv do not map to output. + non_mapping_ids.insert(iaop->getIndexingID()); + has_consumer_id = true; + } else if (producer_tv == iaop->valueTv()) { + // indexing ID of value tv do not map to output. + non_mapping_ids.insert(iaop->getIndexingIDOfValue()); has_consumer_id = true; } } - return std::make_pair(indexed_id, has_consumer_id); + return std::make_pair(non_mapping_ids, has_consumer_id); } } // namespace @@ -120,8 +135,8 @@ std::unordered_map PairwiseLogicalDomainMap::map( squeeze_flags = sop->getSqueezeDimFlags(); } - auto [indexed_producer_id, has_consumer_of_indexed_id] = - getIndexedDomainInfo(producer_tv_, consumer_tv_); + auto [non_mapping_producer_id, has_consumer_of_indexed_id] = + getNonMappingDomainInfo(producer_tv_, consumer_tv_); std::unordered_map dom_map; const auto producer_logical = TensorDomain::noReductions(producer->logical()); @@ -339,13 +354,14 @@ std::unordered_map PairwiseLogicalDomainMap::map( IterDomain* consumer_id = consumer_root.at(itc); // Conditions to check: - // 1. Indirectly accessed IDs (e.g., select) + // 1. Non mapping IDs (e.g., select) // 2. IDs that may have different extents (e.g., non indexed // domains of torchGather) // 3. Squeeze and unsqueeze - // Condition 1: when the producer ID is the dim of a select-like op - if (producer_id == indexed_producer_id) { + // Condition 1: when the producer ID is the dim of a select-like op, or when + // it doesn't map to the output IDs, like indexing IDs of indexPutAccumulate + if (non_mapping_producer_id.count(producer_id) != 0) { // If there's no corresponding consumer, skip the indexed producer if (!has_consumer_of_indexed_id) { itp++; @@ -362,7 +378,8 @@ std::unordered_map PairwiseLogicalDomainMap::map( // Condition 2: Different extents if (auto gop = dynamic_cast(consumer_tv_->definition()); gop != nullptr && !gop->exactSizes() && - producer_tv_ == gop->lookupTv() && producer_id != indexed_producer_id && + producer_tv_ == gop->lookupTv() && + non_mapping_producer_id.count(producer_id) == 0 && !map_different_extents_) { itp++; itc++; diff --git a/csrc/multidevice/c10d_mock.h b/csrc/multidevice/c10d_mock.h index b4ac0152ada..3befb61323f 100644 --- a/csrc/multidevice/c10d_mock.h +++ b/csrc/multidevice/c10d_mock.h @@ -5,6 +5,19 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on + +// This file provides a mock implementation of c10d that builds but doesn't +// function. +// +// nvFuser is sometimes built on a pytorch without c10d. When that +// happens, c10d isn't linked, NVFUSER_DISTRIBUTED is undefined and the +// multi-GPU component of nvFuser is expected to be disabled. +// +// Instead of adding `#ifdef NVFUSER_DISTRIBUTED` in too many places, this file +// provides a buildable mock implementation of c10d to keep nvFuser code less +// divergent. This implementation won't run because tests and user code are +// guarded by Communicator::is_available. + #pragma once #include @@ -170,6 +183,21 @@ struct TCPStoreOptions { static constexpr uint16_t kDefaultPort = 0; }; -class TCPStore : public torch::CustomClassHolder {}; +class TCPStore : public torch::CustomClassHolder { + public: + std::vector get(const std::string&) { + return {}; + } + + void set(const std::string&, const std::vector&) {} + + bool check(const std::vector&) { + return false; + } + + bool deleteKey(const std::string&) { + return false; + } +}; } // namespace c10d diff --git a/csrc/multidevice/executor.h b/csrc/multidevice/executor.h index c1cc3e31cfe..7dd08a87f0a 100644 --- a/csrc/multidevice/executor.h +++ b/csrc/multidevice/executor.h @@ -103,6 +103,10 @@ class MultiDeviceExecutor { return host_ir_executor_->getFusionExecutorCaches(); }; + auto* hostIrEvaluator() const { + return host_ir_executor_.get(); + } + private: // holds the Communicator to be used for execution Communicator& comm_; diff --git a/csrc/multidevice/ipc_handle.cpp b/csrc/multidevice/ipc_handle.cpp index dd96b5a72e8..9a5ec4286b8 100644 --- a/csrc/multidevice/ipc_handle.cpp +++ b/csrc/multidevice/ipc_handle.cpp @@ -95,7 +95,6 @@ std::string IpcHandleCache::getTcpStoreKey( void IpcHandleCache::exchangeHandles( const std::vector& communications) { -#ifdef NVFUSER_DISTRIBUTED Communicator* communicator = &Communicator::getInstance(); const int64_t my_rank = communicator->deviceId(); @@ -152,9 +151,12 @@ void IpcHandleCache::exchangeHandles( insert(communication, std::move(ipc_handles)); } -#else // NVFUSER_DISTRIBUTED - NVF_ERROR(false, "NVFUSER_DISTRIBUTED is not defined"); -#endif // NVFUSER_DISTRIBUTED + + // a second barrier is needed here to ensure all ranks have received the + // memhandles and the keys are deleted from the store before the next call to + // exchangeHandles + // TODO: precisely select what ranks need to wait on that barrier. + communicator->barrier(); } } // namespace nvfuser diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 475d852028b..bce292d1377 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -131,6 +131,13 @@ std::unordered_map mapDeviceParallelTypeToId( continue; } + // rDIDx{i0}, usually a product of an Allreduce or a ReduceScatter, is + // treated as replicated. This way `iDIDx{i0} => rDIDx{i0}` is considered + // resharding. + if (id->isReduction()) { + continue; + } + NVF_ERROR( parallel_type_to_id.try_emplace(parallel_type, id).second, "Found multiple loop IterDomains with the same parallel type (", @@ -564,16 +571,6 @@ bool haveDifferentShardings( return false; } - // iDIDx{i0} => rDIDx{i0} triggers an allreduce even though the two `i0`s - // are equivalent. - if (c_id->isReduction()) { - NVF_ERROR( - !p_id->isReduction(), - "Reduction IterDomains in the producer's logical shouldn't be mapped: ", - p_id); - return false; - } - return simplifyExpr( SimplifyingIrBuilder::eqExpr(p_index, c_index), /*variables=*/{}, @@ -633,14 +630,17 @@ bool isInnerResharding(Expr* expr) { return false; } -void shardAllLike(TensorView* ref, std::vector tvs) { +void shardAllLike( + TensorView* ref, + const std::vector& tvs, + const std::unordered_set& parallel_types) { + if (tvs.empty()) { + return; + } for (auto tv : tvs) { tv->setDeviceMesh(ref->getDeviceMesh()); } - if (!tvs.empty()) { - scheduler_utils::parallelizeAllLike( - ref, tvs, {ParallelType::DIDx, ParallelType::Serial}); - } + scheduler_utils::parallelizeAllLike(ref, tvs, parallel_types); } void shardBetween( diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index 34c510ccb2e..4134c943fac 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -57,8 +57,15 @@ bool haveDifferentShardings( // Returns whether a resharding expr reshards an inner axis bool isInnerResharding(Expr* expr); -// Shards all tensors in tvs like reference -void shardAllLike(TensorView* ref, std::vector tvs); +// Shards all tensors in tvs like reference. +// Accepts a set of parallel types to shard on. +// If empty, all DID parallel types are used. +void shardAllLike( + TensorView* ref, + const std::vector& tvs, + const std::unordered_set& parallel_types = { + kParallelTypeDIDs.begin(), + kParallelTypeDIDs.end()}); // Shards all TVs between from and to AND between TVs created inside a fusion // and to. This is required for (1) expressions like rng_uniform that create a diff --git a/csrc/ops/indexing.cpp b/csrc/ops/indexing.cpp index 5ff75065ff2..f05601fd6a6 100644 --- a/csrc/ops/indexing.cpp +++ b/csrc/ops/indexing.cpp @@ -99,6 +99,44 @@ TensorView* indexSelect( return out; } +// This is a restricted version of torch.index_put(..., accumulate=true) +TensorView* indexPutAccumulate( + TensorView* acc_tv, + TensorView* index_tv, + TensorView* value_tv) { + DataType dtype = acc_tv->getDataType().value(); + NVF_CHECK( + dtype != DataType::Null, "Invalid datatype provided for new value."); + + // broadcast index_tv if applicable + if (index_tv->nDims() == 1) { + index_tv = unsqueeze(index_tv, -1); + } + + std::vector acc_domain = + TensorDomain::noReductions(acc_tv->getLogicalDomain()); + std::vector index_domain = + TensorDomain::noReductions(index_tv->getLogicalDomain()); + std::vector value_domain = + TensorDomain::noReductions(value_tv->getLogicalDomain()); + + NVF_CHECK(acc_domain.size() == 2); + NVF_CHECK(index_domain.size() == 2); + NVF_CHECK(index_domain.at(1)->isBroadcast()); + NVF_CHECK(value_domain.size() == 2); + // IndexPutAccumulateOp semantics + // + // Producers: + // accumulate [ vocab, hidden ] + // broadcast_index [ seq, broadcast ] + // value [ seq, hidden ] + // Consumers: + // output [ vocab, hidden ] + TensorView* out = ops::newValLike(acc_tv, dtype)->as(); + IrBuilder::create(out, acc_tv, index_tv, value_tv); + return out; +} + // torch.gather TensorView* gather(TensorView* inp, int64_t dim, TensorView* index) { auto inp_domain = TensorDomain::noReductions(inp->getLogicalDomain()); diff --git a/csrc/ops/indexing.h b/csrc/ops/indexing.h index c8152c33f82..96eceb515b5 100644 --- a/csrc/ops/indexing.h +++ b/csrc/ops/indexing.h @@ -23,6 +23,12 @@ NVF_API TensorView* indexSelect( int64_t dim, TensorView* index); +// This is a restricted version of torch.index_put(..., accumulate=true) +TensorView* indexPutAccumulate( + TensorView* acc_tv, + TensorView* index_tv, + TensorView* value_tv); + // torch.gather NVF_API TensorView* gather(TensorView* input, int64_t dim, TensorView* index); diff --git a/csrc/options.cpp b/csrc/options.cpp index fadf8cef0bb..33610d5b8fc 100644 --- a/csrc/options.cpp +++ b/csrc/options.cpp @@ -139,6 +139,7 @@ std::unordered_map> Options< {"python_definition_segments", DebugDumpOption::PythonDefinitionSegments}, {"python_frontend_debug", DebugDumpOption::PythonFrontendDebug}, {"sass", DebugDumpOption::Sass}, + {"sass_to_file", DebugDumpOption::SassToFile}, {"segmented_fusion", DebugDumpOption::FusionSegments}, {"segmenter_logging", DebugDumpOption::FusionSegmenterLog}, {"scheduler_params", DebugDumpOption::SchedulerDebug}, @@ -258,40 +259,32 @@ std::unordered_map> Options< return options; } -namespace { - -// These may need to be thread local, or their modifications may need to -// be protected by mutual exclusion for thread safety. At this -// moment, the correctness of modifying option values has to be -// guaranteed by the modifying code. - -DebugDumpOptions active_dump_options; - -EnableOptions active_enable_options; - -DisableOptions active_disable_options; - -ProfilerOptions active_profiler_options; - -} // namespace - template <> Options& OptionsGuard::getCurOptions() { + // Note: Make options thread_local. + // We want the behavior that new threads would inherit options from the *base* + // threads. We need to figure out how to automatically do that before + // switching to thread_local. For now we are using mutex to guard option + // access, which is necessary to avoid data racing. + static DebugDumpOptions active_dump_options; return active_dump_options; } template <> Options& OptionsGuard::getCurOptions() { + static EnableOptions active_enable_options; return active_enable_options; } template <> Options& OptionsGuard::getCurOptions() { + static DisableOptions active_disable_options; return active_disable_options; } template <> Options& OptionsGuard::getCurOptions() { + static ProfilerOptions active_profiler_options; return active_profiler_options; } diff --git a/csrc/options.h b/csrc/options.h index 8e61d2d14d7..b050a0a0199 100644 --- a/csrc/options.h +++ b/csrc/options.h @@ -11,6 +11,7 @@ #include #include +#include #include #include #include @@ -70,7 +71,8 @@ enum class DebugDumpOption { TransformPropagator, //! When running TransformPropagator, print propagation //! path and replay result Cubin, //! Dump compiled CUBIN - Sass, // Dump disassembled SASS + Sass, //! Dump disassembled SASS + SassToFile, //!< Dump disassembled SASS to File Ptx, //! Dump compiled PTX BankConflictInfo, //! Dump bank confliction info SyncMap, //! RAW dependency info @@ -79,7 +81,7 @@ enum class DebugDumpOption { ExprSort, //! Print merging decisions on expression sorting ExprSortVerbose, //! Print verbose debug info on expression sorting LoopRotation, //! Print loop rotation log - Occupancy, // Dump occupancy + Occupancy, //! Dump occupancy IndexType, //! Print the index type of the launched kernel PredicateElimination, //! Print the predicate elimination information IndexingVerbose, //! Print verbose debug info on indexing @@ -179,16 +181,31 @@ class Options { public: Options() : options_(getOptionsFromEnv()) {} + Options(const Options& other) { + std::lock_guard lock_other(other.mutex_); + options_ = other.options_; + } + + Options& operator=(const Options& other) { + std::lock_guard lock_other(other.mutex_); + std::lock_guard lock(mutex_); + options_ = other.options_; + return *this; + } + bool has(OptionEnum option) const { + std::lock_guard lock(mutex_); return options_.count(option); } bool hasAny() const { + std::lock_guard lock(mutex_); return !options_.empty(); } const std::vector& getArgs(OptionEnum option) const { NVF_ERROR(has(option), "Option not set"); + std::lock_guard lock(mutex_); return options_.at(option); } @@ -201,10 +218,12 @@ class Options { } void set(OptionEnum option_type, std::vector option = {}) { + std::lock_guard lock(mutex_); options_[option_type] = option; } void unset(OptionEnum option_type) { + std::lock_guard lock(mutex_); options_.erase(option_type); } @@ -213,6 +232,7 @@ class Options { protected: std::unordered_map> options_; + mutable std::mutex mutex_; }; //! Utility class to temporarily overrride the Enable options, diff --git a/csrc/parallel_dimension_map.h b/csrc/parallel_dimension_map.h index 35a85f9383b..ceb5218ec24 100644 --- a/csrc/parallel_dimension_map.h +++ b/csrc/parallel_dimension_map.h @@ -55,7 +55,7 @@ class ParallelDimensionMap { //! for loading circular buffer tensors. Val* getRawLoad(ParallelType pt) const; - //! The padded val ensures that CTA has 128 threads for the LoadWarp. This + //! The padded val ensures that CTA has 128 threads for the AsyncWarp. This //! function returns the padded val for the warp specialized ParallelType. int64_t getWarpSpecializationPaddedVal(ParallelType pt) const; @@ -89,7 +89,7 @@ class ParallelDimensionMap { //! If we are doing warp specialization on pt, then we need to increase //! the parallel dimension size of pt by one, where the extra one is used - //! as the load warp. In this case, pt becomes non-exact. + //! as the async warp. In this case, pt becomes non-exact. void adjustMappingsForWarpSpecialization(); private: diff --git a/csrc/predicate_compute.cpp b/csrc/predicate_compute.cpp index 49b48fbea1e..86c8d362f4c 100644 --- a/csrc/predicate_compute.cpp +++ b/csrc/predicate_compute.cpp @@ -624,33 +624,33 @@ Val* createMultipleExpressionElectSync( const auto& pdim_map = GpuLower::current()->parallelDimensionMap(); // Determine if warp specialized tma load expression. - ParallelType load_warp_on = ParallelType::Serial; - auto load_warp_loop_it = + ParallelType async_warp_on = ParallelType::Serial; + auto async_warp_loop_it = std::find_if(loops.begin(), loops.end(), [](ForLoop* fl) { return fl->circularBufferLoopStage() == - CircularBufferLoopStage::LoadWarp; + CircularBufferLoopStage::AsyncWarp; }); bool is_register_sharing = false; - if (load_warp_loop_it != loops.end()) { + if (async_warp_loop_it != loops.end()) { auto circular_buffer_type = std::get( GpuLower::current() ->circularBufferInfo() - .getCircularBufferOptionsFor((*load_warp_loop_it)->iter_domain()) + .getCircularBufferOptionsFor((*async_warp_loop_it)->iter_domain()) .type); - load_warp_on = circular_buffer_type.on; + async_warp_on = circular_buffer_type.on; is_register_sharing = circular_buffer_type.num_registers.has_value(); } // Short-circuit: register sharing is not used, don't need to pad a full warp - // group. If we are in a load warp, then the warp-dispatching IfThenElse - // already selects on `load_warp_on`, so we should not generate + // group. If we are in a async warp, then the warp-dispatching IfThenElse + // already selects on `async_warp_on`, so we should not generate // predicates for it here. if (!is_register_sharing) { - Val* conditional = load_warp_on == ParallelType::TIDx + Val* conditional = async_warp_on == ParallelType::TIDx ? pred->fusion()->trueVal() : createElectSyncPredicate(); for (auto pt : {ParallelType::TIDy, ParallelType::TIDz}) { - if (pdim_map.has(pt) && load_warp_on != pt) { + if (pdim_map.has(pt) && async_warp_on != pt) { conditional = SimplifyingIrBuilder::logicalAndExpr( conditional, IrBuilder::eqExpr(NamedScalar::getParallelIndex(pt), zero)); @@ -661,20 +661,20 @@ Val* createMultipleExpressionElectSync( // If not specialized on TIDx, load branch has full size of bdimx, // we can use the first warp, otherwise should use the last warp. - bool use_first_warp = load_warp_on != ParallelType::TIDx; + bool use_first_warp = async_warp_on != ParallelType::TIDx; Val* conditional = createElectSyncPredicate(use_first_warp); for (auto pt : {ParallelType::TIDy, ParallelType::TIDz}) { if (!pdim_map.has(pt)) { continue; } - if (load_warp_on != pt) { + if (async_warp_on != pt) { // Not specialized on pt, use the first thread. conditional = SimplifyingIrBuilder::logicalAndExpr( conditional, IrBuilder::eqExpr(NamedScalar::getParallelIndex(pt), zero)); } else { // Specialized on pt, use the last thread. - Val* raw = GpuLower::current()->parallelDimensionMap().get(load_warp_on); + Val* raw = GpuLower::current()->parallelDimensionMap().get(async_warp_on); conditional = SimplifyingIrBuilder::logicalAndExpr( conditional, IrBuilder::eqExpr( diff --git a/csrc/preseg_passes/insert_reshardings.cpp b/csrc/preseg_passes/insert_reshardings.cpp index eab98af208e..52adef04b30 100644 --- a/csrc/preseg_passes/insert_reshardings.cpp +++ b/csrc/preseg_passes/insert_reshardings.cpp @@ -29,6 +29,20 @@ bool shouldReshardAfter(Expr* expr) { return expr->inputs().size() == 1 && expr->outputs().size() == 1; } +std::unordered_set getParallelTypesForResharding() { + // Consider a reshard case: + // input [DIDx(i0), i1] -> op -> output [i0, DIDx(i1)] + // This is decomposed into: + // input [DIDx(i0), i1] -> op -> output [DIDx(i0), i1] -> set -> + // new_output [i0, DIDx(i1)] ParallelType::Serial is required here so the + // output is sharded as [DIDx(i0), i1] instead of [DIDx(i0), DIDx(i1)] + // when sharding using input as the reference. + std::unordered_set parallel_types{ + kParallelTypeDIDs.begin(), kParallelTypeDIDs.end()}; + parallel_types.insert(ParallelType::Serial); + return parallel_types; +} + void insertReshardingSetsBefore(Fusion* fusion) { // Remove this after we refactor this as a pre-segmenter pass. FusionGuard fg(fusion); @@ -70,7 +84,8 @@ void insertReshardingSetsBefore(Fusion* fusion) { new_inputs.push_back(new_input); expr = ir_utils::replaceValInExprInputs(expr, input, new_input); } - shardAllLike(output, new_inputs); + + shardAllLike(output, new_inputs, getParallelTypesForResharding()); } } @@ -110,7 +125,8 @@ void insertReshardingSetsAfter(Fusion* fusion) { // Update shardings new_output takes output's sharding, // output takes input's sharding shardAllLike(output, {new_output}); - shardAllLike(input, {output}); + + shardAllLike(input, {output}, getParallelTypesForResharding()); } } } @@ -148,6 +164,8 @@ void rFactorLoopSplits(Fusion* fusion) { std::vector rfactor_axes; rfactor_axes.reserve(tv->nDims()); + std::unordered_set reduced_parallel_types; + for (auto&& [i, loop_id] : enumerate(tv->getLoopDomain())) { if (!loop_id->isReduction()) { // rFactor only applies to reduction dimensions. @@ -162,14 +180,47 @@ void rFactorLoopSplits(Fusion* fusion) { continue; } - if (!loop_id->isParallelized()) { + const ParallelType parallel_type = loop_id->getParallelType(); + if (parallel_type == ParallelType::Serial) { // rFactor non-parallelized IDs so they get reduced locally. rfactor_axes.push_back(i); + } else { + reduced_parallel_types.insert(parallel_type); } } if (!rfactor_axes.empty()) { - tv->rFactor(rfactor_axes); + TensorView* local = tv->rFactor(rfactor_axes); + // Before rFactor: + // + // [i{m} i{n} r{k}] + // / \ / \. + // iDIDx{d} i{n/d} rDIDx{d} r{k/d} + // + // After rFactor: + // + // r{k} + // / \. + // [i{m} i{n} iDIDx{d} r{k/d}] + // / \. + // iDIDx{d} i{n/d} + // + // | + // | reduce + // v + // + // [i{m} i{n} rDIDx{d}] + // / \. + // iDIDx{d} i{n/d} + // + // The TensorView returned by rFactor has two iDIDx, which is disallowed. + // The following code unparallelizes the first iDIDx{d}. + for (IterDomain* loop_id : local->getLoopDomain()) { + if (!loop_id->isRFactorProduct() && + reduced_parallel_types.count(loop_id->getParallelType())) { + loop_id->parallelize(ParallelType::Serial); + } + } } } } diff --git a/csrc/preseg_passes/make_resharding_contiguous.cpp b/csrc/preseg_passes/make_resharding_contiguous.cpp index 04fbe0d7173..359f562011d 100644 --- a/csrc/preseg_passes/make_resharding_contiguous.cpp +++ b/csrc/preseg_passes/make_resharding_contiguous.cpp @@ -12,35 +12,140 @@ #include #include #include +#include namespace nvfuser::preseg_passes { namespace { -void setShardedAllocationDomain(TensorView* tv) { - if (!tv->hasAllocation()) { - tv->setAllocationDomain(tv->getLoopDomain(), true); + +// Validates meshes (i.e. all TensorViews have a device mesh or none) and +// returns true if any TensorView has a device mesh. +bool validateMeshes(Fusion* fusion) { + // Validate that meshes are assigned to all TensorViews or none. + bool tv_with_mesh_found = false; + bool tv_without_mesh_found = false; + + for (auto tv : fusion->allTvs()) { + if (tv->isCpuScalar()) { + continue; + } + tv->hasDeviceMesh() ? tv_with_mesh_found = true + : tv_without_mesh_found = true; } + NVF_CHECK( + !(tv_with_mesh_found && tv_without_mesh_found), + "Cannot have some TensorViews with device mesh and some without."); + return tv_with_mesh_found; } + +// Reorders the loop domain in the same relative order as the allocation domain. +// Specifically: +// 1. It uses the exprs between logical and loop domain to split the allocation +// domain +// 2. It reorders the loop domain to match the split allocation domain. +// 3. It computes the contiguity of the transformed allocation domain through +// the split exprs. +// 4. Sets the allocation domain to be the same as the loop domain with the +// computed contiguity. This preserves both the sharding and any stride order. +// Note: Ideally, the loop domain can follow the logical domain and the +// allocation domain can follow the stride order specified/inferred. However, we +// currently require loop domain to be the same as allocation domain. This +// behavior will be modified in the future with allocation and loop domain being +// propagated independently. +void setLoopAndAllocationDomain(TensorView* tv) { + auto alloc_dom = tv->getMaybeAllocationDomain(); + auto contiguity = tv->getContiguity(); + + auto splitContiguity = [](std::optional contiguity) + -> std::pair, std::optional> { + if (!contiguity.has_value()) { + return std::make_pair(std::nullopt, std::nullopt); + } + if (contiguity.value()) { + return std::make_pair(true, true); + } + return std::make_pair(true, false); + }; + + // Allocation domain should be a permutation of logical domain at this point. + std::vector transform_exprs = DependencyCheck::getAllExprsBetween( + {alloc_dom.begin(), alloc_dom.end()}, + {tv->getLoopDomain().begin(), tv->getLoopDomain().end()}); + + NVF_ERROR( + std::all_of( + transform_exprs.begin(), + transform_exprs.end(), + [](Expr* expr) { return expr->isA(); }), + "Expected all transform exprs to be a split between logical and loop domain during sharding propagation."); + + for (auto* expr : transform_exprs) { + Split* split = dynamic_cast(expr); + auto find_it = std::find(alloc_dom.begin(), alloc_dom.end(), split->in()); + NVF_ERROR( + find_it != alloc_dom.end(), + "Split input ", + split->in()->toString(), + " not found in given ids: ", + alloc_dom); + + auto pos = std::distance(alloc_dom.begin(), find_it); + auto [outer_contiguity, inner_contiguity] = + splitContiguity(contiguity.at(pos)); + + alloc_dom[pos] = split->inner(); + alloc_dom.insert(alloc_dom.begin() + pos, split->outer()); + + contiguity[pos] = inner_contiguity; + contiguity.insert(contiguity.begin() + pos, outer_contiguity); + } + + std::optional> permutation = + ir_utils::computePermutation(alloc_dom, tv->getLoopDomain()); + NVF_ERROR( + permutation.has_value(), + "Failed to find a valid permutation for reordering", + tv->getLoopDomain(), + " as ", + alloc_dom); + tv->reorder(permutation.value()); + tv->setAllocationDomain(tv->getLoopDomain(), contiguity); +} + +bool isTvContiguous(TensorView* tv) { + return std::all_of( + tv->getContiguity().begin(), + tv->getContiguity().end(), + [](const std::optional& c) { return c.value_or(true); }); +} + } // namespace void MakeReshardingContiguousPass::runPass(Fusion* fusion) { + bool has_mesh = validateMeshes(fusion); + if (!has_mesh) { + return; + } + for (Expr* expr : fusion->exprs()) { - if (!isResharding(expr)) { - continue; + auto inputs = ir_utils::filterByType(expr->inputs()); + auto outputs = ir_utils::filterByType(expr->outputs()); + + for (auto tv : inputs) { + setLoopAndAllocationDomain(tv); } - for (auto* tv : ir_utils::filterByType(expr->inputs())) { - for (auto c : tv->getContiguity()) { - if (c.has_value()) { - NVF_CHECK( - c.value(), - "Resharding expression input must be contiguous: ", - expr); - } - } - setShardedAllocationDomain(tv); + for (auto tv : outputs) { + setLoopAndAllocationDomain(tv); } - for (auto tv : ir_utils::filterByType(expr->outputs())) { - setShardedAllocationDomain(tv); + + if (isResharding(expr)) { + auto check_contiguity = [&](const auto& tvs) { + return std::all_of(tvs.begin(), tvs.end(), isTvContiguous); + }; + NVF_CHECK( + check_contiguity(inputs) && check_contiguity(outputs), + "Resharding expression must have contiguous inputs and outputs: ", + expr); } } } diff --git a/csrc/preseg_passes/make_resharding_contiguous.h b/csrc/preseg_passes/make_resharding_contiguous.h index 60ded24f76d..8a719683004 100644 --- a/csrc/preseg_passes/make_resharding_contiguous.h +++ b/csrc/preseg_passes/make_resharding_contiguous.h @@ -15,11 +15,18 @@ namespace nvfuser::preseg_passes { -// Resharding expressions are mapped to collective libraries which expect +// This pass: +// 1. Validates that all TensorViews have a device mesh or none. +// 2. Resharding expressions are mapped to collective libraries which expect // contiguous tensors and output contiguous buffers. This pass checks that -// inputs are contiguous and sets the allocation domain of inputs and outputs of -// all resharding expressions. This pass should run after all passes that add or -// update resharding expressions. +// inputs are contiguous. +// 3. Sets the allocation domain of all fusion tvs if they have a device mesh. +// The allocation domain is obtained by transforming the `maybeAllocationDomain` +// using the transforms to loop domain. This ensures that the allocation domain +// has DID loop splits. All iterdomains derived from a given logical iterdomain +// are placed together. See `setLoopAndAllocationDomain` for more details. +// Eventually, this pass should run after `markAliasesPrepare` and +// `AllocationDomainPass` after they are fixed. class MakeReshardingContiguousPass : public OptimizationPass { friend class OptimizationPass; diff --git a/csrc/preseg_passes/optimization_pass.h b/csrc/preseg_passes/optimization_pass.h index 53d8a8acd3c..359a4a42742 100644 --- a/csrc/preseg_passes/optimization_pass.h +++ b/csrc/preseg_passes/optimization_pass.h @@ -18,8 +18,6 @@ namespace nvfuser::preseg_passes { -using FusionPass = std::function; - //! [experimental API] //! Base class to unify optimization pass APIs. //! OptimizationPass can be turned on/off programmatically with the `setEnabled` diff --git a/csrc/preseg_passes/pre_segmenter.cpp b/csrc/preseg_passes/pre_segmenter.cpp index 042f03191f7..52d949ad95a 100644 --- a/csrc/preseg_passes/pre_segmenter.cpp +++ b/csrc/preseg_passes/pre_segmenter.cpp @@ -39,12 +39,6 @@ namespace nvfuser::preseg_passes { debug() << "========================================" << std::endl; } - // For resharding across GPUs. - OptimizationPass::runPass(fusion); - OptimizationPass::runPass(fusion); - OptimizationPass::runPass(fusion); - OptimizationPass::runPass(fusion); - // Replace TensorViews with zero extent. Outputs and inputs may still be empty OptimizationPass::runPass(fusion); // This pass should be placed before ConsecutiveCastPass as more @@ -81,6 +75,16 @@ namespace nvfuser::preseg_passes { OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); + + // All the multidevice passes are moved after allocation related passes: + // MarkAliasesPreparePass, and AllocationDomainPass Multidevice passes will + // try to set the allocation domain for tvs with device mesh which will + // conflict with these passes. + OptimizationPass::runPass(fusion); + OptimizationPass::runPass(fusion); + OptimizationPass::runPass(fusion); + OptimizationPass::runPass(fusion); + OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); OptimizationPass::runPass(fusion); diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 69ba5983060..b3f7344a8e7 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -13,117 +13,296 @@ #include #include #include +#include +#include namespace nvfuser::preseg_passes { namespace { -void validateMeshes(Fusion* fusion) { - // Validate that meshes are assigned to all TensorViews or none. - TensorView* tv_with_mesh = nullptr; - TensorView* tv_without_mesh = nullptr; - for (TensorView* tv : fusion->allTvs()) { - auto update_if_null = [](TensorView*& lhs, TensorView* rhs) { - if (lhs == nullptr) { - lhs = rhs; - } - }; - if (tv->isCpuScalar()) { - continue; +template +std::vector filterTvsWithMesh(const Range& tvs) { + std::vector tvs_with_mesh; + std::copy_if( + tvs.begin(), + tvs.end(), + std::back_inserter(tvs_with_mesh), + [](TensorView* tv) { return tv != nullptr && tv->hasDeviceMesh(); }); + return tvs_with_mesh; +} + +int64_t numDeviceDims(TensorView* tv) { + return std::count_if( + tv->getLoopDomain().begin(), + tv->getLoopDomain().end(), + std::mem_fn(&IterDomain::isDeviceDim)); +} + +// Sort the given tvs by the number of device dimensions in descending order. +// Break ties by the total number of dimensions. +// Only includes TensorViews that have a device mesh. +template +std::vector sortTvsByDeviceDims(const Range& tvs) { + // Filter out TVs without a device mesh + std::vector tvs_with_mesh = filterTvsWithMesh(tvs); + + // Then sort the filtered TVs + std::stable_sort( + tvs_with_mesh.begin(), tvs_with_mesh.end(), [](auto a, auto b) { + int64_t a_device_dims = numDeviceDims(a); + int64_t b_device_dims = numDeviceDims(b); + if (a_device_dims != b_device_dims) { + return a_device_dims > b_device_dims; + } + // Break ties by the total number of dimensions + return a->nDims() > b->nDims(); + }); + + return tvs_with_mesh; +} + +// Order the inputs of the expression based on their priority. +// For linear op, we use weights and bias before input. +// For matmul op, we use weights before input. +// For other ops, we sort the inputs by the number of device dimensions in +// descending order. +std::vector getOrderedReferenceInputs(Expr* expr) { + const auto& inputs = ir_utils::filterByType(expr->inputs()); + if (LinearOp* linear_op = dynamic_cast(expr)) { + // Use weights and bias before input. + return filterTvsWithMesh(std::vector( + {linear_op->inB(), linear_op->bias(), linear_op->inA()})); + } + + if (MatmulOp* matmul_op = dynamic_cast(expr)) { + // Use weights before input. + return filterTvsWithMesh( + std::vector({matmul_op->inB(), matmul_op->inA()})); + } + + // Sort inputs by number of device dimensions in descending order + std::vector sorted_inputs = sortTvsByDeviceDims(inputs); + + return sorted_inputs; +} + +std::vector getOutputsWithoutMesh(Expr* expr) { + const auto& outputs = ir_utils::filterByType(expr->outputs()); + std::vector outputs_without_mesh; + std::copy_if( + outputs.begin(), + outputs.end(), + std::back_inserter(outputs_without_mesh), + [](TensorView* tv) { return !tv->hasDeviceMesh(); }); + return outputs_without_mesh; +} + +// Custom selector to specify direction of transform propagation. +class PropagateShardingsSelector : public SetSelector { + private: + bool allow_c2p_; + bool allow_p2c_; + + public: + explicit PropagateShardingsSelector( + const std::unordered_set& selected_tvs, + bool allow_c2p = true, + bool allow_p2c = true) + : SetSelector(selected_tvs), + allow_c2p_(allow_c2p), + allow_p2c_(allow_p2c) {} + + bool allowC2P(TensorView* from, TensorView* to) override { + return allow_c2p_ && SetSelector::allowC2P(from, to); + } + + bool allowP2C(TensorView* from, TensorView* to) override { + return allow_p2c_ && SetSelector::allowP2C(from, to); + } +}; + +// Reorder the DID axis with the given parallel types to the front. +// Returns the number of device dimensions that were reordered to the front. +// This allows us to limit propagation to only the relevant DID axis. +int64_t selectiveReorderDIDToFront( + TensorView* tv, + const std::unordered_set& selected_parallel_types) { + std::unordered_map old2new; + int64_t current_pos = 0; + + for (auto&& [pos, id] : enumerate(tv->getLoopDomain())) { + if (id->isDeviceDim() && + selected_parallel_types.count(id->getParallelType())) { + old2new[pos] = current_pos; + current_pos++; } + } - if (tv->hasDeviceMesh()) { - update_if_null(tv_with_mesh, tv); - } else { - update_if_null(tv_without_mesh, tv); + tv->reorder(old2new); + return current_pos; +} + +// Returns the set of parallel types seen on the loop domain of the given tvs. +std::unordered_set getParallelTypesToPropagate( + std::vector tvs) { + // Get the set of parallel types seen on the loop domain of the given tvs. + std::unordered_set existing_parallel_types; + for (auto tv : tvs) { + for (auto id : tv->getLoopDomain()) { + if (id->isDeviceDim()) { + existing_parallel_types.insert(id->getParallelType()); + } + } + } + std::unordered_set selected_parallel_types; + for (ParallelType pt : kParallelTypeDIDs) { + if (!existing_parallel_types.count(pt)) { + selected_parallel_types.insert(pt); } } - NVF_CHECK( - tv_with_mesh == nullptr || tv_without_mesh == nullptr, - "Found ", - tv_with_mesh, - " assigned a mesh and ", - tv_without_mesh, - " not."); + return selected_parallel_types; +} + +void propagateDIDTransform( + TensorView* ref, + std::vector tvs, + int64_t did_pos, + bool allow_c2p, + bool allow_p2c) { + TransformPropagator propagator(ref, did_pos); + PropagateShardingsSelector selector( + {tvs.begin(), tvs.end()}, allow_c2p, allow_p2c); + MaxLogicalDomainInfoSpanningTree(ref, &selector).traverse(&propagator); } + } // namespace +// This presegmentation pass propagates shardings from fusion inputs to +// downstream tensorviews. +// 1. Forward propagating DID loop splits and parallelization from inputs to +// outputs that don't have a mesh using TransformPropagator +// 2. Reshape is handled manually since the DID loop split transforms conflict +// with the reshape root-to-logical transforms if using TransformPropagator +// 3. Back-propagating device meshes to ensure all TensorViews have consistent +// meshes. This also splits and parallelizes unsharded inputs based on outputs. +// See `MultiDevicePresegPassesTest.ResidualAdd` for an example. +// 4. Reorders the loop domain as the allocation order. Ideally, loop domain +// should follow logical domain and allocation domain should follow any stride +// order specified/inferred. However, we currently require loop domain to be the +// same as allocation domain. void PropagateShardingsPass::runPass(Fusion* fusion) { - auto num_device_parallel_dimensions = [](const TensorView* tv) -> int64_t { - return std::count_if( - tv->getLoopDomain().begin(), - tv->getLoopDomain().end(), - std::mem_fn(&IterDomain::isDeviceDim)); - }; - const std::vector& exprs = fusion->exprs(); + for (Expr* expr : exprs) { - const auto& inputs = ir_utils::filterByType(expr->inputs()); - // Pick the "most parallel" input tensor as the reference. This is useful - // for propagating tensor parallelism from weights to MLP's intermediate - // tensors. For example, - // - // x: [b, s, h]; replicated. - // w0: [h, 4*h]; column-wise sharded. - // w1: [4*h, h]; row-wise sharded. - // y = matmul(x, w0) - // z = matmul(y, w1) - // - // With the above heuristic, `y` can be automatically sharded column-wise. - TensorView* ref_input = nullptr; - auto max_num_dids = std::numeric_limits::min(); - for (auto* input : inputs) { - if (!input->hasDeviceMesh()) { - continue; - } - int64_t num_dids = num_device_parallel_dimensions(input); - if (num_dids > max_num_dids) { - max_num_dids = num_dids; - ref_input = input; - } + // Note: Tvs without a mesh are assumed to have no manual sharding + // annotation and are sharded like the first producer Tv. + const auto& outputs_without_mesh = getOutputsWithoutMesh(expr); + if (outputs_without_mesh.empty()) { + continue; } - if (ref_input == nullptr) { + + const auto& reference_inputs = getOrderedReferenceInputs(expr); + + if (reference_inputs.empty()) { continue; } + // Propagate shardings from reference inputs in order. + for (auto* ref_input : reference_inputs) { + // Skip if the input has no device mesh or is nullptr. + NVF_ERROR( + ref_input != nullptr && ref_input->hasDeviceMesh(), + "Reference input ", + ref_input, + " has no device mesh."); - // Note: Tvs without a mesh are assumed to have no manual sharding - // annotation and are sharded like the first producer Tv. - const auto& outputs = ir_utils::filterByType(expr->outputs()); - std::vector outputs_without_mesh; - for (auto* tv : outputs) { - if (!tv->hasDeviceMesh()) { - outputs_without_mesh.push_back(tv); - } + // Reorder the DID axis to the front only if it does not have a parallel + // type already seen on the outputs. This avoids propagating the same + // parallel type on multiple axis of the output when using multiple + // reference inputs. Consider out [M, N] = linear (inp [M, K], weight (N, + // K)) with inp sharded on M ([DIDx(d), M/d, K]) and weight sharded on N + // ([DIDy(d), N/d, K]). We propagate from weights first, so the output + // will be [M, DIDx(d), N/d]. When we propagate from inp next, we should + // not propagate DIDx parallel type to the output. Otherwise, the output + // will have multiple DIDx shardings which is invalid. + std::unordered_set selected_parallel_types = + getParallelTypesToPropagate(outputs_without_mesh); + + // This restricts the transform propagation to only the relevant DID axis. + int64_t did_pos = + selectiveReorderDIDToFront(ref_input, selected_parallel_types); + + // Propagate the DID loop split to the outputs without mesh. + propagateDIDTransform( + /*ref=*/ref_input, + /*tvs=*/outputs_without_mesh, + /*did_pos=*/did_pos, + /*allow_c2p=*/false, + /*allow_p2c=*/true); + + // Apply parallelization on the outputs without mesh. + shardAllLike(ref_input, outputs_without_mesh, selected_parallel_types); } - shardAllLike(ref_input, outputs_without_mesh); } // Back-propagate device meshes. This makes sure all TensorViews have a mesh // if any of them has one. This is needed in addition to the forward // propagation for ops that don't take any TensorView operands, e.g., // `uniform` used in dropout. See MultiDeviceTest.BackpropMeshes for an - // example. - for (auto i_expr = exprs.rbegin(); i_expr != exprs.rend(); i_expr++) { - Expr* expr = *i_expr; + // example. For non-fusion inputs, we also propagate shardings from outputs to + // inputs. See MultiDevicePresegPassesTest.ResidualAdd for an example. + for (Expr* expr : exprs | std::views::reverse) { const auto& outputs = ir_utils::filterByType(expr->outputs()); - auto i_output = std::find_if( - outputs.begin(), - outputs.end(), - std::mem_fn(&TensorView::hasDeviceMesh)); - if (i_output == outputs.end()) { + // All outputs of an expression (Welford, SDPA) should be uniformly sharded. + // We pick the most parallel output as the reference. + // This is to avoid picking seed/offset tvs in SDPA. + std::vector sorted_outputs = sortTvsByDeviceDims(outputs); + + if (sorted_outputs.empty()) { + // No output with a device mesh. continue; } - TensorView* output_with_mesh = *i_output; + TensorView* ref_output = sorted_outputs.front(); + NVF_ERROR( + ref_output != nullptr && ref_output->hasDeviceMesh(), + "Reference output ", + ref_output, + " has no device mesh."); + + // For fusion inputs, only check if they have a device mesh. We do not + // modify their sharding. For non-fusion inputs, we try to propagate + // shardings from the reference output for parallel types that are not + // already present. const auto& inputs = ir_utils::filterByType(expr->inputs()); + std::vector sharding_candidates; for (auto* tv : inputs) { - if (!tv->hasDeviceMesh()) { - tv->setDeviceMesh(output_with_mesh->getDeviceMesh()); + if (tv->isFusionInput()) { + if (!tv->hasDeviceMesh()) { + tv->setDeviceMesh(ref_output->getDeviceMesh()); + } + continue; + } + if (!tv->hasDeviceMesh() || numDeviceDims(tv) == 0) { + sharding_candidates.push_back(tv); } } - } - validateMeshes(fusion); + if (sharding_candidates.empty()) { + continue; + } + + int64_t did_pos = selectiveReorderDIDToFront(ref_output, {}); + // Note: We do not have to manually shard for reshape here. + // TransformPropagator can handle reshapes when going from consumer to + // producer. + propagateDIDTransform( + /*ref=*/ref_output, + /*tvs=*/sharding_candidates, + /*did_pos=*/did_pos, + /*allow_c2p=*/true, + /*allow_p2c=*/false); + shardAllLike(ref_output, sharding_candidates); + } } } // namespace nvfuser::preseg_passes diff --git a/csrc/runtime/compiled_kernel.cpp b/csrc/runtime/compiled_kernel.cpp index 9679987881a..9defcbaab9c 100644 --- a/csrc/runtime/compiled_kernel.cpp +++ b/csrc/runtime/compiled_kernel.cpp @@ -719,6 +719,13 @@ std::unique_ptr compileSource( compiled_kernel->cubin_filename = dumpCompiledCodeToFile(compiled_kernel->cubin, func_name, ".cubin"); } + if (isDebugDumpEnabled(DebugDumpOption::SassToFile)) { + std::string sass_str = + disassembleBinary(compiled_kernel->cubin, "-fun 1 -c"); + compiled_kernel->sass = {sass_str.begin(), sass_str.end()}; + compiled_kernel->sass_filename = + dumpCompiledCodeToFile(compiled_kernel->sass, func_name, ".sass"); + } } if (!compile_to_sass || isDebugDumpEnabled(DebugDumpOption::Ptx)) { diff --git a/csrc/runtime/executor_utils.h b/csrc/runtime/executor_utils.h index 9b04b82e85d..843f4ead896 100644 --- a/csrc/runtime/executor_utils.h +++ b/csrc/runtime/executor_utils.h @@ -43,6 +43,8 @@ struct CudaExecutable : public NonCopyable { std::string cubin_filename; std::string kernel_name; std::string compile_args; + std::vector sass; + std::string sass_filename; long block_size = -1; int register_spills = -1; }; diff --git a/csrc/runtime/fusion_kernel_runtime.cpp b/csrc/runtime/fusion_kernel_runtime.cpp index c57abe4fcca..4afe4279a85 100644 --- a/csrc/runtime/fusion_kernel_runtime.cpp +++ b/csrc/runtime/fusion_kernel_runtime.cpp @@ -199,56 +199,6 @@ flatbuffers::Offset FusionKernelRuntime::serialize( segmented_fusion_fb); } -namespace { -std::vector toposortExprs( - SegmentedFusion* fusion, - SegmentedGroup* group) { - const std::vector& exprs = group->exprs(); - std::vector exprs_to_print(exprs.begin(), exprs.end()); - std::unordered_set exprs_to_print_set(exprs.begin(), exprs.end()); - std::unordered_set exprs_visited; - std::vector sorted_list; - while (!std::all_of( - exprs_to_print.begin(), - exprs_to_print.end(), - [&exprs_visited](auto expr) { return exprs_visited.count(expr); })) { - bool expr_added_to_sorted_list = false; - for (auto expr : exprs_to_print) { - if (!exprs_visited.count(expr)) { - bool add_this_expr = true; - // Check if any of the inputs of current - // expression within the group - // hasn't been visited - for (auto input : expr->inputs()) { - if (input->definition() && - exprs_to_print_set.count(input->definition()) && - !exprs_visited.count(input->definition())) { - add_this_expr = false; - break; - } - } - - // Append the current group to sorted list - // and mark visited - if (add_this_expr) { - expr_added_to_sorted_list = true; - exprs_visited.insert(expr); - sorted_list.push_back(expr); - break; - } - } - } - NVF_ERROR( - expr_added_to_sorted_list, - "group debug print failed, exprs within given vector not a DAG"); - } - NVF_CHECK( - sorted_list.size() == group->exprs().size(), - "Exprs should not have been lost during toposortExprs"); - return sorted_list; -} -} // namespace - void FusionKernelRuntime::deserialize( const serde::FusionKernelRuntime* buffer, int8_t device_index) { @@ -540,8 +490,7 @@ void FusionKernelRuntime::compileFusionParallel(KernelArgumentHolder args) { } else { // push back segment's exprs into the container as top level // expressions - for (auto* expr : - toposortExprs(segmented_fusion_.get(), group_to_run)) { + for (auto* expr : group_to_run->stablyOrderedExprs()) { auto cloned_expr = ir_cloner.clone(expr); hic->pushBackTopLevelExprs(cloned_expr); } diff --git a/csrc/scheduler/expr_eval_sched.cpp b/csrc/scheduler/expr_eval_sched.cpp index 684beae07fa..5a35c7884a7 100644 --- a/csrc/scheduler/expr_eval_sched.cpp +++ b/csrc/scheduler/expr_eval_sched.cpp @@ -51,7 +51,13 @@ bool ExprEvalScheduler::canScheduleCompileTime(Fusion* fusion) { return false; } - if (exprs.front()->isOneOf()) { + // TODO: remove IndexPutAccumulateOp + if (exprs.front() + ->isOneOf< + SdpaFwdOp, + SdpaBwdOp, + EmbeddingFwdOp, + IndexPutAccumulateOp>()) { return true; } diff --git a/csrc/scheduler/hopper_multi_matmul.cpp b/csrc/scheduler/hopper_multi_matmul.cpp index ec3b509ba32..b24c353c37d 100644 --- a/csrc/scheduler/hopper_multi_matmul.cpp +++ b/csrc/scheduler/hopper_multi_matmul.cpp @@ -237,41 +237,6 @@ void HopperMultipleMatmulScheduler::reorderBlockTileTraversal( } } -TensorView* HopperMultipleMatmulScheduler::cacheAfter( - TensorView* orig, - LoadStoreOpType op_type, - CacheOp cache_op, - bool propagate_allocation_domain) { - const std::vector orig_alloc = orig->getMaybeAllocationDomain(); - - TensorView* c = - orig->cacheAfter(op_type, cache_op, propagate_allocation_domain); - - if (propagate_allocation_domain) { - const std::vector cache_alloc = c->getMaybeAllocationDomain(); - NVF_ERROR(orig_alloc.size() == cache_alloc.size()); - for (size_t i : arange(orig_alloc.size())) { - ValGroup vg = graph_->toGroup(orig_alloc[i]); - graph_->initializeVal(cache_alloc[i], vg); - } - } - - const std::vector orig_logical = - TensorDomain::noReductions(orig->getLogicalDomain()); - const std::vector cache_logical = c->getLogicalDomain(); - // in split-K we do rFactor which gives us a full = sum(partial) - // where partial has root domain that matches the logical domain of the - // original tensor. The logical domain contains Iteration transforms of the - // Reduction axis in the original mma output. - NVF_ERROR(orig_logical.size() == cache_logical.size()); - for (size_t i : arange(orig_logical.size())) { - ValGroup vg = graph_->toGroup(orig_logical[i]); - graph_->initializeVal(cache_logical[i], vg); - } - - return c; -} - std::vector> HopperMultipleMatmulScheduler:: blockTileTensors(const std::vector& tvs) { if (canonical_dim_ordering_.empty()) { @@ -623,19 +588,19 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() { NVF_ERROR(d->definition() && d->definition()->isA()); TensorView* dc = d->definition()->input(0)->as(); - // NOTE: cacheBefore does not work with blockTileTensors - // cacheInputsAndOutputs creates a cache_before for each output. - // Apply cacheAfter to the existing cache tensor for output. // The chain of operations storing data to global memory: // registers -> (stmatrix) -> smem -> (tma_store) -> gmem - TensorView* d_smem = cacheAfter(dc, LoadStoreOpType::Set); + TensorView* d_smem = cacheBefore(d, LoadStoreOpType::Set); std::vector tvs_to_schedule{d, d_smem}; - bool dc_in_mma_results = + bool dc_is_mma_result = std::find(mma_results_.begin(), mma_results_.end(), dc) != mma_results_.end(); + bool dc_is_splitk_sum = params_->splitk_factor > 1 && + std::find(splitk_sums_.begin(), splitk_sums_.end(), dc) != + splitk_sums_.end(); - if (!dc_in_mma_results) { + if (!dc_is_mma_result && !dc_is_splitk_sum) { // Skip scheduling dc if it is an mma_result. This can happen if we are // not casting back to half-precision in the output tvs_to_schedule.push_back(dc); @@ -666,7 +631,7 @@ void HopperMultipleMatmulScheduler::scheduleEpilogue() { // Should not propagate if the dc is a mma output as the mma output has // already been scheduled. - if (!dc_in_mma_results) { + if (!dc_is_mma_result && !dc_is_splitk_sum) { auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation( dc->getLoopDomain()); dc->setLoopDomain(s.as()); @@ -790,12 +755,12 @@ void HopperMultipleMatmulScheduler::setUpCircularBuffering() { // register properly in that case. cb_type = (CircularBufferType)WarpSpecialized(ParallelType::TIDy); } else { - constexpr int64_t num_registers_load_warp = 40; + constexpr int64_t num_registers_async_warp = 40; constexpr int64_t num_registers_compute_warp = 232; cb_type = (CircularBufferType)WarpSpecialized( ParallelType::TIDy, std::make_pair( - num_registers_load_warp, num_registers_compute_warp)); + num_registers_async_warp, num_registers_compute_warp)); } break; } diff --git a/csrc/scheduler/hopper_multi_matmul.h b/csrc/scheduler/hopper_multi_matmul.h index 854d0705234..a46d046f2e9 100644 --- a/csrc/scheduler/hopper_multi_matmul.h +++ b/csrc/scheduler/hopper_multi_matmul.h @@ -124,14 +124,6 @@ class HopperMultipleMatmulScheduler : public MultipleMatmulScheduler { TensorView* tv, std::vector& outer_dim_roles); - //! This calls orig->cacheAfter() and also updates the broadcast graph to - //! reflect the new IterDomain mappings - TensorView* cacheAfter( - TensorView* orig, - LoadStoreOpType op_type = LoadStoreOpType::Set, - CacheOp cache_op = CacheOp::AllLevels, - bool propagate_allocation_domain = false); - //! Do block tiling for a collection of TensorViews. The tensors should be //! unscheduled before this method is called. //! 1) Axes will be ordered according to canonicalDimOrdering, and then axes diff --git a/csrc/scheduler/matmul_utils.cpp b/csrc/scheduler/matmul_utils.cpp index e65d571ed76..98c0851472e 100644 --- a/csrc/scheduler/matmul_utils.cpp +++ b/csrc/scheduler/matmul_utils.cpp @@ -396,7 +396,7 @@ bool fillDefaultHopperHeuristic( // the _other_ dimension to create a new inner dimension. We find the swizzle // factor that is largest and has the least quantization when we divide that // other dimension by the swizzle factor. - int64_t swizzled_tiles = Mtiles <= Ntiles ? Ntiles : Mtiles; + int64_t swizzled_tiles = Mtiles >= Ntiles ? Ntiles : Mtiles; mparams->cta_order = Mtiles <= Ntiles ? MatmulParams::TileRasterizationOrder::ColumnMajor : MatmulParams::TileRasterizationOrder::RowMajor; diff --git a/csrc/scheduler/multi_matmul.cpp b/csrc/scheduler/multi_matmul.cpp index 377629d99e6..1976bdf6bef 100644 --- a/csrc/scheduler/multi_matmul.cpp +++ b/csrc/scheduler/multi_matmul.cpp @@ -227,4 +227,58 @@ void MultipleMatmulScheduler::cacheInputsAndOutputs(bool skip_intermediates) { } } +TensorView* MultipleMatmulScheduler::cacheBefore( + TensorView* orig, + LoadStoreOpType op_type) { + TensorView* c = orig->cacheBefore(op_type); + + const std::vector orig_logical = + TensorDomain::noReductions(orig->getLogicalDomain()); + const std::vector cache_logical = c->getLogicalDomain(); + NVF_ERROR(orig_logical.size() == cache_logical.size()); + for (size_t i : arange(orig_logical.size())) { + // The domain of orig gets transferred to c and a new domain is applied to + // orig + ValGroup vg = graph_->toGroup(cache_logical[i]); + graph_->initializeVal(orig_logical[i], vg); + } + + return c; +} + +TensorView* MultipleMatmulScheduler::cacheAfter( + TensorView* orig, + LoadStoreOpType op_type, + CacheOp cache_op, + bool propagate_allocation_domain) { + const std::vector orig_alloc = orig->getMaybeAllocationDomain(); + + TensorView* c = + orig->cacheAfter(op_type, cache_op, propagate_allocation_domain); + + if (propagate_allocation_domain) { + const std::vector cache_alloc = c->getMaybeAllocationDomain(); + NVF_ERROR(orig_alloc.size() == cache_alloc.size()); + for (size_t i : arange(orig_alloc.size())) { + ValGroup vg = graph_->toGroup(orig_alloc[i]); + graph_->initializeVal(cache_alloc[i], vg); + } + } + + const std::vector orig_logical = + TensorDomain::noReductions(orig->getLogicalDomain()); + const std::vector cache_logical = c->getLogicalDomain(); + // in split-K we do rFactor which gives us a full = sum(partial) + // where partial has root domain that matches the logical domain of the + // original tensor. The logical domain contains Iteration transforms of the + // Reduction axis in the original mma output. + NVF_ERROR(orig_logical.size() == cache_logical.size()); + for (size_t i : arange(orig_logical.size())) { + ValGroup vg = graph_->toGroup(orig_logical[i]); + graph_->initializeVal(cache_logical[i], vg); + } + + return c; +} + } // namespace nvfuser diff --git a/csrc/scheduler/multi_matmul.h b/csrc/scheduler/multi_matmul.h index 7bcb86f0ead..8f9d200bba7 100644 --- a/csrc/scheduler/multi_matmul.h +++ b/csrc/scheduler/multi_matmul.h @@ -60,6 +60,20 @@ class MultipleMatmulScheduler { TensorView* operand, int64_t vec_size) = 0; + //! This calls orig->cacheBefore() and also updates the broadcast graph to + //! reflect the new IterDomain mappings + TensorView* cacheBefore( + TensorView* orig, + LoadStoreOpType op_type = LoadStoreOpType::Set); + + //! This calls orig->cacheAfter() and also updates the broadcast graph to + //! reflect the new IterDomain mappings + TensorView* cacheAfter( + TensorView* orig, + LoadStoreOpType op_type = LoadStoreOpType::Set, + CacheOp cache_op = CacheOp::AllLevels, + bool propagate_allocation_domain = false); + protected: Fusion* fusion_; const MatmulParams* params_; diff --git a/csrc/scheduler/normalization_inner_outer.cpp b/csrc/scheduler/normalization_inner_outer.cpp index fcb7260501d..8264f046382 100644 --- a/csrc/scheduler/normalization_inner_outer.cpp +++ b/csrc/scheduler/normalization_inner_outer.cpp @@ -7,1064 +7,17 @@ // clang-format on #include #include -#include +#include +#include +#include #include -#include #include #include -#include -#include #include namespace nvfuser { namespace { - -// The roundup is due to the fact that the shared memory buffer is allocated -// as: ceilDiv(dim_size / vectorize_factor, threads_per_block). -// Let after_vect = dim_size / vectorize_factor; -// n_batch = ceilDiv(after_vect, threads_per_block); -// Then the shared memory buffer size is n_batch * vectorize_factor * -// threads_per_block * data_type_size. This function returns the maximum -// possible shared memory buffer size considering all possible block sizes. -int64_t roundUpSharedMemory( - int64_t tv_buffer_size, - int64_t data_type_size, - int64_t vectorize_factor, - int64_t threads_per_block_min, - int64_t threads_per_block_max, - int64_t threads_per_block_step) { - int64_t dim_size = tv_buffer_size / data_type_size; - int64_t after_vect = dim_size / vectorize_factor; - int64_t max_smem = 0; - for (int64_t threads_per_block = threads_per_block_min; - threads_per_block <= threads_per_block_max; - threads_per_block += threads_per_block_step) { - int64_t n_batch = ceilDiv(after_vect, threads_per_block); - max_smem = std::max( - max_smem, - n_batch * vectorize_factor * threads_per_block * data_type_size); - } - return max_smem; -} - -// Return the broadcast tvs that are broadcast to the iteration dimensions of -// the inner reduction tv. These tvs are reused in the loop over the iteration -// dimension. This reuse reduced the number loads from gmem and this tensor -// is likely the first candidate to be moved to shared memory when the register -// space runs low. -std::vector getOuterBroadcastTvs( - Fusion* fusion, - const std::vector& reduction_tvs) { - // set reference broadcast mask using the first inner reduction tv - std::vector ref_broadcast_mask; - for (auto tv : reduction_tvs) { - if (scheduler_utils::isFastestDimReduction(tv)) { - const auto& logical = tv->getLogicalDomain(); - ref_broadcast_mask.reserve(logical.size()); - for (const auto i : arange(logical.size())) { - ref_broadcast_mask.push_back(!logical.at(i)->isReduction()); - } - break; - } - } - NVF_ERROR(!ref_broadcast_mask.empty(), "ref_broadcast_mask is empty!"); - - // find the broadcast tensor whose broadcast mask is same to the reference - std::vector outer_broadcast_tvs; - for (auto tv : fusion->allTvs()) { - if (std::any_of( - tv->getLoopDomain().begin(), - tv->getLoopDomain().end(), - [](IterDomain* id) { return id->isBroadcast(); })) { - if (auto bcast = dynamic_cast(tv->definition())) { - if (bcast->getBroadcastDimFlags() == ref_broadcast_mask) { - outer_broadcast_tvs.emplace_back(tv); - } - } - } - } - return outer_broadcast_tvs; -} - -// Size of buffers storing intermediate outer reduction results -// TODO: check if we can directly start with [buffer_size = 1] -int64_t partialOuterReductionBufferSize( - const std::vector& reduction_tvs, - SchedulerRuntimeInfo& runtime_info) { - int64_t partial_reduction_buffer_size = 0; - for (auto buffer : reduction_tvs) { - if (scheduler_utils::isFastestDimReduction(buffer)) { - continue; - } - int64_t buffer_size = -1; - for (auto id : buffer->getLogicalDomain()) { - if (id->isReduction() || id->isBroadcast()) { - continue; - } - auto id_size = runtime_info.expressionEvaluator().evaluate(id->extent()); - NVF_ERROR(id_size.hasValue(), "Could not infer persistent buffer size."); - if (buffer_size == -1) { - buffer_size = id_size.as(); - } else { - buffer_size *= id_size.as(); - } - } - buffer_size = (buffer_size == -1) ? 0 - : buffer_size * - (int64_t)dataTypeSize(buffer->getDataType().value(), - runtime_info.getIndexType()); - partial_reduction_buffer_size += buffer_size; - } - return partial_reduction_buffer_size; -} - -// Decide where to store persistent buffers. -// By default, they reside in registers. -// If register space runs low but there's ample shared memory, -// move one or more buffers to shared memory until the register space is -// sufficient. -struct PersistentBufferStorageParams { - // representing buffers that are stored in shared memory, other buffers are - // stored in registers. - std::vector smem_persistent_buffers; - - // Total number of bytes occupied by all persistent buffers stored in shared - // memory. - int64_t smem_buffer_size = -1; - - // Total number of bytes occupied by all persistent buffers stored in - // registers. - int64_t regs_buffer_size = -1; - - // Additional shared memory usage per block that is not associated with - // persistent buffers. This includes memory for driver overhead and workspace - // for reductions. - int64_t smem_overhead = -1; - - // Flag indicating whether there are sufficient registers and shared memory - // available to accommodate all persistent buffers as required for efficient - // execution. - bool has_enough_regs_and_smem = false; - - // Flag indicating whether the persistent buffers are recomputed using inputs. - bool project_to_input = false; -}; - -// Prioritize keeping buffers used by outer broadcast tensors to shared memory -// because: -// (1) They are reused in every iteration of the outer loop, has lower IO. -// (2) Load occurs before the outer loop. Temporary register usage won't -// increase register pressure since the loop is the high-pressure region. -std::vector sortProjectableBufferInputs( - const std::vector& projectable_buffer_inputs, - const std::vector& outer_broadcast_tvs) { - // mark whether the buffer is used by outer broadcast tensors - std::unordered_map is_used_by_outer_bcast; - for (auto buffer : projectable_buffer_inputs) { - is_used_by_outer_bcast[buffer] = std::any_of( - outer_broadcast_tvs.begin(), - outer_broadcast_tvs.end(), - [&buffer](TensorView* tv) { - return DependencyCheck::isDependencyOf(buffer, tv); - }); - } - - // sort based on [is_used_by_outer_bcast] - std::vector sorted_buffer = projectable_buffer_inputs; - std::sort( - sorted_buffer.begin(), - sorted_buffer.end(), - [&](TensorView* a, TensorView* b) { - return !is_used_by_outer_bcast[a] && is_used_by_outer_bcast[b]; - }); - return sorted_buffer; -} - -PersistentBufferStorageParams getPersistentBufferStorageParams( - Fusion* fusion, - SchedulerRuntimeInfo& runtime_info, - HeuristicDataCache* data_cache, - const std::vector& reduction_tvs, - const int64_t vectorize_factor, - const int64_t threads_per_block_min, - const int64_t threads_per_block_max) { - FUSER_PERF_SCOPE( - "normalization_inner_outer::getPersistentBufferStorageParams"); - - PersistentBufferStorageParams buffer_params; - - auto persistent_buffer_info_entry = - HeuristicDataCacheEntry( - data_cache, [&fusion]() { - return std::make_unique( - scheduler_utils::persistentBuffers(fusion)); - }); - - auto& persistent_buffer_info = persistent_buffer_info_entry.get(); - - auto persistent_buffer_size_info = scheduler_utils::persistentBufferSize( - fusion, runtime_info, persistent_buffer_info, data_cache); - - // Project to inputs when there is at least one outer broadcast tensor or - // projected persistent buffer size is smaller. When projecting to inputs, the - // outer broadcast tensor is reused in the loop over the iteration dimension, - // test shows it is faster than the non-projected version which requires - // reload from gmem for each iteration. - // Note: in current use cases (layer norm bwd and RMS norm bwd), there are - // outer broadcast tvs and always project to inputs. - const auto& outer_broadcast_tvs = getOuterBroadcastTvs(fusion, reduction_tvs); - normalization_scheduler_utils::BufferProjectionStrategy project_strategy = - normalization_scheduler_utils::isProjectBufferToInputs( - fusion, - runtime_info, - reduction_tvs, - persistent_buffer_info, - persistent_buffer_size_info, - InnerOuterPersistentKernelScheduler::schedulerType(), - /*can_use_smem_persistent=*/true, - outer_broadcast_tvs.empty()); - - buffer_params.project_to_input = - (project_strategy == - normalization_scheduler_utils::BufferProjectionStrategy:: - ProjectToInputs); - - const auto dev_prop = at::cuda::getCurrentDeviceProperties(); - int64_t smem_overhead = scheduler_utils::getSharedMemoryOverheadPerBlock( - fusion, reduction_tvs, threads_per_block_max); - int64_t available_smem = - (int64_t)dev_prop->sharedMemPerMultiprocessor - smem_overhead; - int64_t available_regs = scheduler_utils::register_file_size_56k; - buffer_params.smem_overhead = smem_overhead; - - // (1) Use both register and shared memory. - // Start with all the cached input buffers in shared memory, they are loaded - // from global memory uses async copy which bypasses L1 cache. Outer reduction - // buffers are used to accumulate partial results of the outer reduction. They - // are not loaded from global memory and requires frequent read/write. So, - // they are always stored in registers. - // TODO: We may also move outer reduction buffers to shared - // memory to avoid segmentation when there are many outer reductions and - // hardware has larger shared memory, but these applications are rare, so this - // is not considered here. - auto buffers = buffer_params.project_to_input - ? persistent_buffer_info.projectable_buffer_inputs - : persistent_buffer_info.persistent_buffers; - - // Add buffers that are inputs to the fusion. They are not included in - // projectable_buffer_inputs since they are not projectable. - if (buffer_params.project_to_input) { - for (auto tv : persistent_buffer_info.persistent_buffers) { - if (tv->isFusionInput()) { - buffers.push_back(tv); - } - } - } - - // Needs to use rounded shared memory size to avoid over usage. - // key : buffer tv. - // val : register size and rounded shared memory size - std::unordered_map> - required_size_regs_smem_map; - int64_t total_smem_buffer_size = 0; - for (auto buffer : buffers) { - int64_t buffer_size_regs = scheduler_utils::getPersistentBufferSizeOfTensor( - buffer, runtime_info, persistent_buffer_info); - int64_t buffer_size_smem = roundUpSharedMemory( - buffer_size_regs, - dataTypeSize(buffer->getDataType().value()), - vectorize_factor, - threads_per_block_min, - threads_per_block_max, - dev_prop->warpSize); - required_size_regs_smem_map[buffer] = - std::make_pair(buffer_size_regs, buffer_size_smem); - total_smem_buffer_size += buffer_size_smem; - } - buffer_params.smem_buffer_size = total_smem_buffer_size; - buffer_params.regs_buffer_size = - partialOuterReductionBufferSize(reduction_tvs, runtime_info); - if (buffer_params.regs_buffer_size <= available_regs && - buffer_params.smem_buffer_size <= available_smem) { - buffer_params.smem_persistent_buffers = buffers; - buffer_params.has_enough_regs_and_smem = true; - return buffer_params; - } - - // Moving outer reduction buffer to shared memory is not considered yet, - // set to false if the outer reduction buffer size exceeds the register size. - if (buffer_params.regs_buffer_size > available_regs) { - buffer_params.has_enough_regs_and_smem = false; - return buffer_params; - } - - // (2) Now, shared memory is overused, move some buffers to registers. - // (2.1) Sort the candidate persistent buffers. No need to sort since the - // sorting is based on whether the buffer is used by outer broadcast tensors. - if (!outer_broadcast_tvs.empty()) { - buffers = sortProjectableBufferInputs(buffers, outer_broadcast_tvs); - } - // (2.2) Before this loop, all cached input buffers are in shared memory. Move - // buffer from shared memory to register. - int64_t n_regs_buffer = -1; - const int n_buffers = (int)buffers.size(); - for (int i = 0; i < n_buffers; i++) { - auto current_tv = buffers[i]; - auto [buffer_size_regs, buffer_size_smem] = - required_size_regs_smem_map.at(current_tv); - buffer_params.regs_buffer_size += buffer_size_regs; - buffer_params.smem_buffer_size -= buffer_size_smem; - - // The first-i buffers to are moved from shared memory to register - // If both the register buffer size and shared memory buffer size are within - // the allowable limit, we found a good configuration. - if (buffer_params.regs_buffer_size <= available_regs && - buffer_params.smem_buffer_size <= available_smem) { - n_regs_buffer = i + 1; - break; - } - // Register buffer size exceeds the limit, can't move more to registers. - // Break the loop. - if (buffer_params.regs_buffer_size > available_regs) { - break; - } - } - - // n_regs_buffer > 0 indicats a good configuration is found. - // The first n_regs_buffer buffers are stored in registers and last [n_buffers - // - n_regs_buffer] are stored in shared memory. - if (n_regs_buffer > 0) { - buffer_params.has_enough_regs_and_smem = true; - auto n_smem_buffer = n_buffers - n_regs_buffer; - buffer_params.smem_persistent_buffers.reserve(n_smem_buffer); - for (int i = 0; i < n_smem_buffer; i++) { - buffer_params.smem_persistent_buffers.emplace_back( - buffers[n_buffers - 1 - i]); - } - } else { - buffer_params.has_enough_regs_and_smem = false; - } - return buffer_params; -} - -// The innerOuterPersistentHeuristic is tuned for layer_norm backward on A100 -// ======= Method if hidden_size > 1024 ======= -// (1) Inner reduction is one reduction per block. Reduction domain is -// parallelized by TIDx and TIDy, Iteration domain is parallelized by BIDy. -// (2) Outer reduction is done in two-steps. The first step is partial -// reduction, reduction domain is parallelized by BIDy, iteration domain is -// parallelized by TIDx and TIDy. The partial results are written to gmem -// followed by a grid sync. The second step is block reduction, the reduction -// domain is parallelized by TIDy, the iteration domain is parallelized by TIDx -// and BIDy. -// ======= Method if hidden_size <= 1024 ======= -// (1) Inner reduction is multi-reductions per blocks. Reduction domain is -// parallelized by TIDx, Iteration domain is parallelized by BIDy and TIDy. -// (2) Outer reduction is same to cases where hidden_size > 1024 except the -// second step where in this case, the reduction domain is parallelized by TIDx -// and the iteration domain is parallelized by TIDy and BIDy. This switch -// between TIDx and TIDy is because: -// (a) We can do warp reduction with TIDx -// (b) TIDx*BIDy is usually much larger than hidden_size, e.g. 128*216 = 1024*27 -// this means without switch only 1/27 of the threads is used. -std::unique_ptr innerOuterPersistentHeuristic( - const int64_t outer_dim_numel, - const int64_t inner_dim_numel, - const int64_t regs_buffer_size, - const int64_t smem_buffer_size, - const int64_t smem_overhead, - const size_t tmp_gmem_dtype_size, - const size_t vectorize_factor, - const int64_t hp_threads_per_block_min, - const int64_t hp_threads_per_block_max, - const bool project_to_input, - const PrimDataType index_type) { - auto rparams = std::make_unique( - InnerOuterPersistentKernelScheduler::schedulerType()); - rparams->project_persistent_buffers = project_to_input; - rparams->cparams.index_type = index_type; - const auto dev_prop = at::cuda::getCurrentDeviceProperties(); - const int64_t device_multiprocessor_count = - (int64_t)dev_prop->multiProcessorCount; - // Parameters for inner reduction: - // Reduction dim: inner_vect, inner_batch, bdimx and bdimy - // Iteration dim: gdimy - - // Parameters for outer reduction: - // Reduction dim: bdimy - // Iteration dim: vectorization_factor_outer, bdimx, gdimy - struct InnerOuterParams { - int64_t inner_vect = -1; - int64_t inner_batch = -1; - int64_t bdimx = -1; - int64_t bdimy = -1; - int64_t bdimz = -1; - int64_t gdimy = -1; - int64_t tmp_gmem_write_vect = -1; - int64_t vectorization_factor_outer = -1; - int64_t threads_per_block = -1; - // derived metrics for sorting - int64_t warps_per_sm = -1; - int64_t required_register_per_thread = -1; - int64_t available_register_per_thread = -1; - - void verify() { - NVF_ERROR(inner_vect != -1, "inner_vect is not set."); - NVF_ERROR(inner_batch != -1, "inner_batch is not set."); - NVF_ERROR(bdimx != -1, "bdimx is not set."); - NVF_ERROR(bdimy != -1, "bdimy is not set."); - NVF_ERROR(gdimy != -1, "gdimy is not set."); - NVF_ERROR(tmp_gmem_write_vect != -1, "tmp_gmem_write_vect is not set."); - NVF_ERROR( - vectorization_factor_outer != -1, - "vectorization_factor_outer is not set."); - } - std::string toString() const { - std::stringstream ss; - ss << "inner_vect: " << inner_vect << ", inner_batch: " << inner_batch - << ", bdimx: " << bdimx << ", bdimy: " << bdimy << ", bdimz: " << bdimz - << ", gdimy: " << gdimy - << ", tmp_gmem_write_vect: " << tmp_gmem_write_vect - << ", vectorization_factor_outer: " << vectorization_factor_outer - << ", threads_per_block: " << threads_per_block - << ", warps_per_sm: " << warps_per_sm - << ", required_register_per_thread: " << required_register_per_thread - << ", available_register_per_thread: " - << available_register_per_thread; - return ss.str(); - } - }; - - // Set a minimum workload for each thread to take advantage of low - // intra-threads communication cost. - // Tuned for layer_norm backward on A100, still works fine on H100. - auto getMinimumBatch = [&]() -> int64_t { - if (inner_dim_numel >= 3072l) { - if (outer_dim_numel <= 2048l && inner_dim_numel == 3072l) { - return 3l; - } else { - return 4l; - } - } else if (inner_dim_numel >= 2048l) { - return 2l; - } - return 1l; - }; - - // Estimate register usage per thread based on buffer size. - // Assuming a constant register overhead for non-buffer related usage, - // and all the register buffers are stored in registers. - auto getEstimatedRegisterUsage = [&](int64_t batch_mul_vect) { - int64_t persistent_buffer_size = - regs_buffer_size / inner_dim_numel * batch_mul_vect; - int64_t estimated_register_count = - persistent_buffer_size / scheduler_utils::bytes_per_register + - scheduler_utils::register_overhead; - return std::min( - estimated_register_count, scheduler_utils::max_registers_per_thread); - }; - - // Estimate max blocks per sm based on register and shared memory usage. - auto getBlocksPerSM = [&](const int64_t threads_per_sm, - const int64_t threads_per_block, - const int64_t warp_size) { - // check register limitation on blocks per sm - constexpr int64_t warp_allocation_granularity = 4; - const int64_t allocated_warps_per_block = - ceilDiv( - ceilDiv(threads_per_block, warp_size), - warp_allocation_granularity) * - warp_allocation_granularity; - int64_t max_blocks_per_sm_regs = scheduler_utils::safeDiv( - threads_per_sm / warp_size, allocated_warps_per_block); - // check shared memory limitation on blocks per sm - int64_t max_blocks_per_sm_smem = - (int64_t)dev_prop->sharedMemPerMultiprocessor / - (smem_overhead + smem_buffer_size); - return std::min(max_blocks_per_sm_regs, max_blocks_per_sm_smem); - }; - - // In the inner reduction part of the kernel, gdimy is used to parallelize the - // outer dimension. The kernel is a cooperative kernel, so the number of - // blocks should be as large as possible to achieve a high occupancy unless - // outer dim is too small which may lead large workload for the final outer - // reduction. So, gdimy is drvied from the number of blocks per sm and limited - // to ensure at least 8 rows per block. - // TODO: re-evaluate this 8 rows per block requirement. - auto getGdimy = [&](int64_t inner_vect, - int64_t threads_per_block, - int64_t inner_batch) { - int64_t reg_per_thread = - getEstimatedRegisterUsage(inner_vect * inner_batch); - int64_t threads_per_sm = getThreadsPerSMGivenRegPerThread(reg_per_thread); - int64_t blocks_per_sm = - getBlocksPerSM(threads_per_sm, threads_per_block, dev_prop->warpSize); - int64_t gdimy = blocks_per_sm * device_multiprocessor_count; - const int64_t outer_iter_min = 8; - const int64_t gdimy_max = scheduler_utils::roundUpToN( - ceilDiv(outer_dim_numel, outer_iter_min), device_multiprocessor_count); - while (gdimy > gdimy_max && blocks_per_sm > 1) { - blocks_per_sm -= 1; - gdimy = blocks_per_sm * device_multiprocessor_count; - } - return gdimy; - }; - - // The inner reduction part of the kernel also does a partial outer reduction - // and stores the partial results in tmp gmem and then reloaded to finish the - // outer reduciton. This function set the vectorization factor for write and - // and read of the partial outer reduction result. - // For write to tmp gmem, follows vectorization factor of inner reduction - // but don't exceed 16 bytes. - // For read from tmp gmem, since the paralelization is changed, a different - // vectorization factor is used to optimize the - // number of reaductions per thread. - auto getOuterReductionBufferVectFactor = [&](int64_t inner_vect) { - constexpr int64_t max_gmem_vect_access_bytes = 16; - const int64_t max_tmp_gmem_vect_factor = std::min( - max_gmem_vect_access_bytes / (int64_t)tmp_gmem_dtype_size, inner_vect); - int64_t tmp_gmem_write_vect = max_tmp_gmem_vect_factor; - const int64_t workload_per_thread = inner_dim_numel >= 4096 ? 4l : 2l; - int64_t vectorization_factor_outer = - std::min(workload_per_thread, max_tmp_gmem_vect_factor); - return std::make_pair(tmp_gmem_write_vect, vectorization_factor_outer); - }; - - // In the outer reduction part of the kernel, inner and outer dims are - // parallelized as: - // --- inner dim: vect, bdimx, gdimy ---- - // --- outer dim: bdimy ----------------- - // This function splits the threads_per_block into bdimx and bdimy using: - // bdimx = ceilDiv(inner_dim_numel / vect, gdimy) - // bdimy = threads_per_block / bdimx - auto getBdimxBdimy = [&](int64_t threads_per_block, - int64_t vectorization_factor_outer, - int64_t gdimy) { - // For widely used hidden sizes, threads_per_block has factor of 8, roundup - // to increase the probability of bdimx * bdimy == threads_per_block. - int64_t bdimx = scheduler_utils::roundUpPow2Or8( - ceilDiv(inner_dim_numel / vectorization_factor_outer, gdimy)); - // if still not divisible, e.g. threads_per_block = 256, bdimx = 40. - // increase bdimx to make it divisible. Under worst case, bdimx equals to - // threads_per_block. - while (threads_per_block % bdimx) { - bdimx = std::min(bdimx + 8, threads_per_block); - } - // Set OuterParams Reduction dim: bdimy. - int64_t bdimy = threads_per_block / bdimx; - NVF_ERROR( - bdimy * bdimx == threads_per_block, - " threads_per_block must be divisible by bdimx and bdimy."); - return std::make_pair(bdimx, bdimy); - }; - - // Get the heuristics given vectorization factor and threads per block - auto getHeuristicsGivenVectThreads = [&](int64_t vect_factor, - int64_t threads_per_block) { - InnerOuterParams iop; - // (1) inner reduction - // Reduction dim: inner_batch, threads_per_block, vect_factor - // Iteration dim: gdimy - iop.inner_vect = vect_factor; - iop.threads_per_block = threads_per_block; - iop.inner_batch = - ceilDiv(inner_dim_numel / iop.inner_vect, iop.threads_per_block); - iop.gdimy = - getGdimy(iop.inner_vect, iop.threads_per_block, iop.inner_batch); - // (2) outer reduction - // Iteration dim: gdimy, bdimx, vectorization_factor_outer - // Reduction dim: bdimy - std::tie(iop.tmp_gmem_write_vect, iop.vectorization_factor_outer) = - getOuterReductionBufferVectFactor(iop.inner_vect); - auto [bdimx, bdimy] = getBdimxBdimy( - threads_per_block, iop.vectorization_factor_outer, iop.gdimy); - iop.bdimx = bdimx; - iop.bdimy = bdimy; - // (3) Derived metrics warps_per_sm and register usage for sorting - iop.warps_per_sm = ceilDiv(iop.threads_per_block, dev_prop->warpSize) * - iop.gdimy / device_multiprocessor_count; - iop.available_register_per_thread = - getRegPerThreadGivenThreadsPerSM(dev_prop->warpSize * iop.warps_per_sm); - iop.required_register_per_thread = - getEstimatedRegisterUsage(iop.inner_vect * iop.inner_batch); - return iop; - }; - - // Use the maximum vectorization factor - const int64_t vect_factor = (int64_t)vectorize_factor; - - // Set a reasonable range for threads per block based on the number of - // elements in the inner dimension after vectorization. - // Start from 128 or a smaller number if inner dim is small. - const int64_t after_vect = inner_dim_numel / vect_factor; - const int64_t batch_min = getMinimumBatch(); - int64_t threads_per_block_min = hp_threads_per_block_min; - threads_per_block_min = std::min(threads_per_block_min, after_vect); - threads_per_block_min = scheduler_utils::roundUpPow2(threads_per_block_min); - - // star max threads per block from min threads per block - int64_t threads_per_block_max = threads_per_block_min; - // increase to cover the whole inner dim - threads_per_block_max = - std::max(threads_per_block_max, ceilDiv(after_vect, batch_min)); - // round up to power of 2 - threads_per_block_max = scheduler_utils::roundUpPow2(threads_per_block_max); - // don't go beyond the maximum threads per block - threads_per_block_max = - std::min(threads_per_block_max, hp_threads_per_block_max); - - // Store all the possible heuristics based on different threads per block. - // Vectorizaton is fixed at the maximum value. - std::vector iop_candidates; - for (auto threads_per_block = threads_per_block_max; - threads_per_block >= threads_per_block_min; - threads_per_block /= 2) { - iop_candidates.emplace_back( - getHeuristicsGivenVectThreads(vect_factor, threads_per_block)); - } - - // Sort the heuristics based on the register usage and occupancy. - std::stable_sort( - iop_candidates.begin(), - iop_candidates.end(), - [](const InnerOuterParams& a, const InnerOuterParams& b) { - // If a thread can use more registers than required, there is a high - // chance that it can avoid register spilling and compiler can optimize - // for better instruction level parallelism. - int64_t extra_regs_a = - a.available_register_per_thread - a.required_register_per_thread; - int64_t extra_regs_b = - b.available_register_per_thread - b.required_register_per_thread; - if (extra_regs_a > 0 && extra_regs_b < 0) { - return true; - } else if (extra_regs_a < 0 && extra_regs_b > 0) { - return false; - } - // High occupancy provides better threads level parallelism. - // 25% is sufficient since ILP is high due to persistent batch sizes - // which is equivalent to unrolling inner dim. - if (a.warps_per_sm != b.warps_per_sm && - (a.warps_per_sm < 16 || b.warps_per_sm < 16)) { - return a.warps_per_sm > b.warps_per_sm; - } - // Tie breaker, smaller threads_per_block to reduce communication - // overhead - return a.threads_per_block < b.threads_per_block; - }); - - // Pick the best heuristic - auto iop = iop_candidates.front(); - - // Special case, when inner_dim_numel <= 1024, bdimx is usually small - // after divide by inner_vect and inner_batch. In this case, bdimy is used to - // parallelize outer_dim instead of inner_dim. This pattern is named multi - // reductions per block (mrpb). - if (inner_dim_numel <= 1024) { - rparams->multiple_reds_per_blk = true; - rparams->tidx_for_outer_reduction = true; - - // Step-1, InnerParams, Reduction dim: inner_vect(reuse), - // inner_batch(reuse), bdimx - iop.bdimx = ceilDiv(inner_dim_numel, iop.inner_vect * iop.inner_batch); - - // Step-2, InnerParams, Iteration dim: gdimy, bdimy (in next step) - iop.gdimy = getGdimy(iop.inner_vect, iop.bdimx, iop.inner_batch); - - // Step-3, OuterParams, Iteration dim: vectorization_factor_outer(reuse), - // bdimy, gdimy (in previous step). - // WAR for https://github.com/NVIDIA/Fuser/issues/3428 - iop.bdimy = 1; - - // Step-4, OuterParams, Reduction dim: bdimx (already done) - iop.warps_per_sm = ceilDiv(iop.bdimx * iop.bdimy, dev_prop->warpSize) * - iop.gdimy / device_multiprocessor_count; - iop.available_register_per_thread = - getRegPerThreadGivenThreadsPerSM(dev_prop->warpSize * iop.warps_per_sm); - - if (iop.bdimx % dev_prop->warpSize == 0) { - rparams->pad_inner_reduction_to_warp = true; - rparams->pad_outer_reduction_to_warp = true; - } - rparams->block_dim_iter_dom = ParallelType::TIDy; - rparams->combined_split_grid_inner_dim = - iop.vectorization_factor_outer * iop.bdimy * iop.gdimy < - inner_dim_numel; - } else { - rparams->block_dim_inner_reduction_extra = ParallelType::TIDy; - rparams->combined_split_grid_inner_dim = - iop.vectorization_factor_outer * iop.bdimx * iop.gdimy < - inner_dim_numel; - rparams->static_bdimx = true; - rparams->static_bdimy = true; - iop.bdimz = ceilDiv( - ceilDiv( - ceilDiv(inner_dim_numel / iop.inner_vect, iop.bdimx), iop.bdimy), - iop.inner_batch); - NVF_ERROR(iop.bdimz == 1, "bdimz must be 1."); - } - - // check all the parameters in InnerOuterParams are set. - iop.verify(); - - rparams->persistent_kernel = true; - rparams->fastest_dim = true; - rparams->combined_inner_outer = true; - // tmp_gmem is the intermediate result of outer reduction, its dtype is float, - // so the maximum vectorization factor is 4. - rparams->vectorization_factor_outer = iop.vectorization_factor_outer; - rparams->vectorization_factor_tmp_gmem_write = iop.tmp_gmem_write_vect; - rparams->cparams.maxrregcount = iop.available_register_per_thread; - rparams->unroll_factor_inner_reduction = iop.inner_vect; - rparams->batches_per_block_inner_reduction = iop.inner_batch; - rparams->block_dim_inner_reduction = ParallelType::TIDx; - rparams->vectorize_inner_reduction = iop.inner_vect > 1; - rparams->split_grid_dim_iter_dom_outer = true; - rparams->grid_dim_iter_dom = ParallelType::BIDy; - - rparams->lparams = LaunchParams( - LaunchParams::UNINITIALIZED_VAL, - iop.gdimy, - LaunchParams::UNINITIALIZED_VAL, - iop.bdimx, - iop.bdimy, - LaunchParams::UNINITIALIZED_VAL); - - if (!rparams->smem_persistent_buffers.empty()) { - rparams->tag = - "InnerOuter Register and Shared Memory Persistent Heuristic.\n"; - } else { - rparams->tag = "InnerOuter Register Persistent Heuristic.\n"; - } - - if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { - debug() << "\n===== Combined InnerOuter Reduction Stats ========\n" - << "outer_dim_numel: " << outer_dim_numel << "\n" - << "inner_dim_numel: " << inner_dim_numel << "\n" - << "regs_buffer_size: " << regs_buffer_size << "\n" - << "smem_buffer_size: " << smem_buffer_size << "\n" - << "smem_overhead: " << smem_overhead << "\n" - << "vectorize_factor_input: " << iop.inner_vect << "\n" - << "vectorization_factor_tmp_gmem_write: " - << iop.tmp_gmem_write_vect << "\n" - << "vectorization_factor_outer: " << iop.vectorization_factor_outer - << "\n" - << "multiple_reds_per_blk: " << rparams->multiple_reds_per_blk - << "\n" - << "warps_per_sm: " << iop.warps_per_sm << "\n" - << "gdimy: " << iop.gdimy << "\n" - << "block(" << (iop.bdimx) << ", " << iop.bdimy << ", " << 1 << ")"; - debug() << rparams->toString() << std::endl; - } - return rparams; -} - -std::unique_ptr innerOuterWarpSpecializedTmaHeuristic( - const int64_t outer_dim_numel, - const int64_t inner_dim_numel, - const int64_t regs_buffer_size, - const int64_t smem_buffer_size, - const int64_t smem_overhead, - const size_t tmp_gmem_dtype_size, - const size_t vectorize_factor, - const int64_t hp_threads_per_block_min, - const int64_t hp_threads_per_block_max, - const bool project_to_input, - const PrimDataType index_type) { - auto rparams = std::make_unique( - InnerOuterPersistentKernelScheduler::schedulerType()); - rparams->project_persistent_buffers = project_to_input; - rparams->cparams.index_type = index_type; - const auto dev_prop = at::cuda::getCurrentDeviceProperties(); - const int64_t device_multiprocessor_count = - (int64_t)dev_prop->multiProcessorCount; - // Parameters for inner reduction: - // Reduction dim: inner_vect, inner_batch, bdimx and bdimy - // Iteration dim: gdimy - - // Parameters for outer reduction: - // Reduction dim: bdimy - // Iteration dim: vectorization_factor_outer, bdimx, gdimy - struct InnerOuterParams { - int64_t inner_vect = -1; - int64_t inner_batch = -1; - int64_t bdimx = -1; - int64_t bdimy = -1; - int64_t bdimz = -1; - int64_t gdimy = -1; - int64_t tmp_gmem_write_vect = -1; - int64_t vectorization_factor_outer = -1; - int64_t threads_per_block = -1; - // derived metrics for sorting - int64_t warps_per_sm = -1; - int64_t required_register_per_thread = -1; - int64_t available_register_per_thread = -1; - - void verify() { - NVF_ERROR(inner_vect != -1, "inner_vect is not set."); - NVF_ERROR(inner_batch != -1, "inner_batch is not set."); - NVF_ERROR(bdimx != -1, "bdimx is not set."); - NVF_ERROR(bdimy != -1, "bdimy is not set."); - NVF_ERROR(gdimy != -1, "gdimy is not set."); - NVF_ERROR(tmp_gmem_write_vect != -1, "tmp_gmem_write_vect is not set."); - NVF_ERROR( - vectorization_factor_outer != -1, - "vectorization_factor_outer is not set."); - } - std::string toString() const { - std::stringstream ss; - ss << "inner_vect: " << inner_vect << ", inner_batch: " << inner_batch - << ", bdimx: " << bdimx << ", bdimy: " << bdimy << ", bdimz: " << bdimz - << ", gdimy: " << gdimy - << ", tmp_gmem_write_vect: " << tmp_gmem_write_vect - << ", vectorization_factor_outer: " << vectorization_factor_outer - << ", threads_per_block: " << threads_per_block - << ", warps_per_sm: " << warps_per_sm - << ", required_register_per_thread: " << required_register_per_thread - << ", available_register_per_thread: " - << available_register_per_thread; - return ss.str(); - } - }; - - // Set a minimum workload for each thread to take advantage of low - // intra-threads communication cost. - // Tuned for layer_norm backward on A100, still works fine on H100. - auto get_minimum_batch = [&]() -> int64_t { - if (inner_dim_numel >= 3072l) { - if (outer_dim_numel <= 2048l && inner_dim_numel == 3072l) { - return 3l; - } else { - return 4l; - } - } else if (inner_dim_numel >= 2048l) { - return 2l; - } - return 1l; - }; - - // Estimate register usage per thread based on buffer size. - // Assuming a constant register overhead for non-buffer related usage, - // and all the register buffers are stored in registers. - auto get_estimated_register_usage = [&](int64_t batch_mul_vect) { - int64_t persistent_buffer_size = - regs_buffer_size / inner_dim_numel * batch_mul_vect; - int64_t estimated_register_count = - persistent_buffer_size / scheduler_utils::bytes_per_register + - scheduler_utils::register_overhead; - return std::min( - estimated_register_count, scheduler_utils::max_registers_per_thread); - }; - - // The inner reduction part of the kernel also does a partial outer reduction - // and stores the partial results in tmp gmem and then reloaded to finish the - // outer reduciton. This function set the vectorization factor for write and - // and read of the partial outer reduction result. - // For write to tmp gmem, follows vectorization factor of inner reduction - // but don't exceed 16 bytes. - // For read from tmp gmem, since the paralelization is changed, a different - // vectorization factor is used to optimize the - // number of reaductions per thread. - auto get_outer_reduction_buffer_vect_factor = [&](int64_t inner_vect) { - constexpr int64_t max_gmem_vect_access_bytes = 16; - const int64_t max_tmp_gmem_vect_factor = std::min( - max_gmem_vect_access_bytes / (int64_t)tmp_gmem_dtype_size, inner_vect); - int64_t tmp_gmem_write_vect = max_tmp_gmem_vect_factor; - const int64_t workload_per_thread = inner_dim_numel >= 4096 ? 4l : 2l; - int64_t vectorization_factor_outer = - std::min(workload_per_thread, max_tmp_gmem_vect_factor); - return std::make_pair(tmp_gmem_write_vect, vectorization_factor_outer); - }; - - // In the outer reduction part of the kernel, inner and outer dims are - // parallelized as: - // --- inner dim: vect, bdimx, gdimy ---- - // --- outer dim: bdimy ----------------- - // This function splits the threads_per_block into bdimx and bdimy using: - // bdimx = ceilDiv(inner_dim_numel / vect, gdimy) - // bdimy = threads_per_block / bdimx - auto get_bdimx_bdimy = [&](int64_t threads_per_block, - int64_t vectorization_factor_outer, - int64_t gdimy) { - // For widely used hidden sizes, threads_per_block has factor of 8, roundup - // to increase the probability of bdimx * bdimy == threads_per_block. - int64_t bdimx = scheduler_utils::roundUpPow2Or8( - ceilDiv(inner_dim_numel / vectorization_factor_outer, gdimy)); - // if still not divisible, e.g. threads_per_block = 256, bdimx = 40. - // increase bdimx to make it divisible. Under worst case, bdimx equals to - // threads_per_block. - while (threads_per_block % bdimx) { - bdimx = std::min(bdimx + 8, threads_per_block); - } - // Set OuterParams Reduction dim: bdimy. - int64_t bdimy = threads_per_block / bdimx; - NVF_ERROR( - bdimy * bdimx == threads_per_block, - " threads_per_block must be divisible by bdimx and bdimy."); - return std::make_pair(bdimx, bdimy); - }; - - // Get the heuristics given vectorization factor and threads per block - auto get_heuristics_given_vect_threads = [&](int64_t vect_factor, - int64_t threads_per_block) { - InnerOuterParams iop; - // (1) inner reduction - // Reduction dim: inner_batch, threads_per_block, vect_factor - // Iteration dim: gdimy - iop.inner_vect = vect_factor; - iop.threads_per_block = threads_per_block; - iop.inner_batch = - ceilDiv(inner_dim_numel / iop.inner_vect, iop.threads_per_block); - iop.gdimy = device_multiprocessor_count; - - // (2) outer reduction - // Iteration dim: gdimy, bdimx, vectorization_factor_outer - // Reduction dim: bdimy - std::tie(iop.tmp_gmem_write_vect, iop.vectorization_factor_outer) = - get_outer_reduction_buffer_vect_factor(iop.inner_vect); - auto [bdimx, bdimy] = get_bdimx_bdimy( - threads_per_block, iop.vectorization_factor_outer, iop.gdimy); - iop.bdimx = bdimx; - iop.bdimy = bdimy; - // (3) Derived metrics warps_per_sm and register usage for sorting - iop.warps_per_sm = ceilDiv(iop.threads_per_block, dev_prop->warpSize) * - iop.gdimy / device_multiprocessor_count; - iop.available_register_per_thread = - getRegPerThreadGivenThreadsPerSM(dev_prop->warpSize * iop.warps_per_sm); - iop.required_register_per_thread = - get_estimated_register_usage(iop.inner_vect * iop.inner_batch); - return iop; - }; - - // Use the maximum vectorization factor - const int64_t vect_factor = (int64_t)vectorize_factor; - - // Set a reasonable range for threads per block based on the number of - // elements in the inner dimension after vectorization. - // Start from 128 or a smaller number if inner dim is small. - const int64_t after_vect = inner_dim_numel / vect_factor; - const int64_t batch_min = get_minimum_batch(); - int64_t threads_per_block_min = hp_threads_per_block_min; - threads_per_block_min = std::min(threads_per_block_min, after_vect); - threads_per_block_min = scheduler_utils::roundUpPow2(threads_per_block_min); - - // star max threads per block from min threads per block - int64_t threads_per_block_max = threads_per_block_min; - // increase to cover the whole inner dim - threads_per_block_max = - std::max(threads_per_block_max, ceilDiv(after_vect, batch_min)); - // round up to power of 2 - threads_per_block_max = scheduler_utils::roundUpPow2(threads_per_block_max); - // don't go beyond the maximum threads per block - threads_per_block_max = - std::min(threads_per_block_max, hp_threads_per_block_max); - - // Store all the possible heuristics based on different threads per block. - // Vectorizaton is fixed at the maximum value. - std::vector iop_candidates; - for (auto threads_per_block = threads_per_block_max; - threads_per_block >= threads_per_block_min; - threads_per_block /= 2) { - iop_candidates.emplace_back( - get_heuristics_given_vect_threads(vect_factor, threads_per_block)); - } - - // Sort the heuristics based on the register usage and occupancy. - std::stable_sort( - iop_candidates.begin(), - iop_candidates.end(), - [](const InnerOuterParams& a, const InnerOuterParams& b) { - // If a thread can use more registers than required, there is a high - // chance that it can avoid register spilling and compiler can optimize - // for better instruction level parallelism. - int64_t extra_regs_a = - a.available_register_per_thread - a.required_register_per_thread; - int64_t extra_regs_b = - b.available_register_per_thread - b.required_register_per_thread; - if (extra_regs_a > 0 && extra_regs_b < 0) { - return true; - } else if (extra_regs_a < 0 && extra_regs_b > 0) { - return false; - } - // High occupancy provides better threads level parallelism. - // 25% is sufficient since ILP is high due to persistent batch sizes - // which is equivalent to unrolling inner dim. - if (a.warps_per_sm != b.warps_per_sm && - (a.warps_per_sm < 16 || b.warps_per_sm < 16)) { - return a.warps_per_sm > b.warps_per_sm; - } - // Tie breaker, smaller threads_per_block to reduce communication - // overhead - return a.threads_per_block < b.threads_per_block; - }); - - // Pick the best heuristic - auto iop = iop_candidates.front(); - rparams->block_dim_inner_reduction_extra = ParallelType::TIDy; - rparams->combined_split_grid_inner_dim = - iop.vectorization_factor_outer * iop.bdimx * iop.gdimy < inner_dim_numel; - rparams->static_bdimx = true; - rparams->static_bdimy = true; - iop.bdimz = ceilDiv( - ceilDiv(ceilDiv(inner_dim_numel / iop.inner_vect, iop.bdimx), iop.bdimy), - iop.inner_batch); - NVF_ERROR(iop.bdimz == 1, "bdimz must be 1."); - - // check all the parameters in InnerOuterParams are set. - iop.verify(); - - rparams->persistent_kernel = true; - rparams->fastest_dim = true; - rparams->combined_inner_outer = true; - // tmp_gmem is the intermediate result of outer reduction, its dtype is float, - // so the maximum vectorization factor is 4. - rparams->vectorization_factor_outer = iop.vectorization_factor_outer; - rparams->vectorization_factor_tmp_gmem_write = iop.tmp_gmem_write_vect; - rparams->cparams.maxrregcount = iop.available_register_per_thread; - rparams->unroll_factor_inner_reduction = iop.inner_vect; - rparams->batches_per_block_inner_reduction = iop.inner_batch; - rparams->block_dim_inner_reduction = ParallelType::TIDx; - rparams->vectorize_inner_reduction = iop.inner_vect > 1; - rparams->split_grid_dim_iter_dom_outer = true; - rparams->grid_dim_iter_dom = ParallelType::BIDy; - - rparams->lparams = LaunchParams( - LaunchParams::UNINITIALIZED_VAL, - iop.gdimy, - LaunchParams::UNINITIALIZED_VAL, - iop.bdimx, - iop.bdimy, - LaunchParams::UNINITIALIZED_VAL); - - if (!rparams->smem_persistent_buffers.empty()) { - rparams->tag = - "InnerOuter Register and Shared Memory Persistent Heuristic.\n"; - } else { - rparams->tag = "InnerOuter Register Persistent Heuristic.\n"; - } - - if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { - debug() << "\n===== Combined InnerOuter Reduction Stats ========\n" - << "outer_dim_numel: " << outer_dim_numel << "\n" - << "inner_dim_numel: " << inner_dim_numel << "\n" - << "regs_buffer_size: " << regs_buffer_size << "\n" - << "smem_buffer_size: " << smem_buffer_size << "\n" - << "smem_overhead: " << smem_overhead << "\n" - << "vectorize_factor_input: " << iop.inner_vect << "\n" - << "vectorization_factor_tmp_gmem_write: " - << iop.tmp_gmem_write_vect << "\n" - << "vectorization_factor_outer: " << iop.vectorization_factor_outer - << "\n" - << "multiple_reds_per_blk: " << rparams->multiple_reds_per_blk - << "\n" - << "warps_per_sm: " << iop.warps_per_sm << "\n" - << "gdimy: " << iop.gdimy << "\n" - << "block(" << (iop.bdimx) << ", " << iop.bdimy << ", " << 1 << ")"; - debug() << rparams->toString() << std::endl; - } - return rparams; -} - std::unique_ptr getInnerOuterPersistentHeuristics( Fusion* fusion, SchedulerRuntimeInfo& runtime_info, @@ -1143,7 +96,7 @@ std::unique_ptr getInnerOuterPersistentHeuristics( NVF_ERROR( !persistent_buffer_info.persistent_buffers.empty(), "Persistent scheduler requires persistent buffers."); - auto buffer_params = getPersistentBufferStorageParams( + auto buffer_params = inner_outer_utils::getPersistentBufferStorageParams( fusion, runtime_info, data_cache, @@ -1152,14 +105,15 @@ std::unique_ptr getInnerOuterPersistentHeuristics( hp.threads_per_block_min, hp.threads_per_block_max); - std::unique_ptr rparams; - + auto rparams = std::make_unique( + InnerOuterPersistentKernelScheduler::schedulerType()); // Ultimately, we want the heuristic to decide between using the // warp-specialized version or the multi-wave version. The enable option is a // temporary configuration to facilitate testing during development without // disrupting existing behavior. if (isOptionEnabled(EnableOption::WarpSpecializedNormalization)) { - rparams = innerOuterWarpSpecializedTmaHeuristic( + inner_outer_tma_warp_specialized::getHeuristics( + rparams.get(), properties.total_iteration_numel, properties.total_reduction_numel, buffer_params.regs_buffer_size, @@ -1171,9 +125,9 @@ std::unique_ptr getInnerOuterPersistentHeuristics( hp.threads_per_block_max, buffer_params.project_to_input, runtime_info.getIndexType()); - rparams->tma_warp_specialized = true; } else { - rparams = innerOuterPersistentHeuristic( + inner_outer_multi_wave::getHeuristics( + rparams.get(), properties.total_iteration_numel, properties.total_reduction_numel, buffer_params.regs_buffer_size, @@ -1194,544 +148,6 @@ std::unique_ptr getInnerOuterPersistentHeuristics( return rparams; } -void scheduleReductionCombinedOuter( - Fusion* fusion, - const ReductionParams* rparams, - const std::vector& outer_reduction_tvs, - std::vector& cached_gmem, - std::vector& cached_gmem_reload, - std::vector& outer_reference_tvs, - std::unordered_set& boundaryNodesSet) { - auto mergeReductionOrIterDomains = [](TensorView* tv, bool mergeReduction) { - int prev_i = -1; - for (int i = static_cast(tv->nDims()) - 1; i >= 0; i--) { - if (mergeReduction == tv->axis(i)->isReduction()) { - if (prev_i == -1) { - prev_i = i; - } else { - tv->merge(i, prev_i); - prev_i = i; - } - } - } - }; - for (auto& outer_reduction_tv : outer_reduction_tvs) { - // Similar to the inner reduction, we need to reorder the outer reduction tv - // when there are view operations. - if (!ir_utils::getViewOps(fusion).empty()) { - // Reorder reference_tv after propagating the view operation. This will - // reorder for better merging. - outer_reduction_tv->reorder( - scheduler_utils::domainReorderAsLogicalMap(outer_reduction_tv)); - } - - // merge tensorview to [reduction, iteraiton] domains - mergeReductionOrIterDomains(outer_reduction_tv, true); - mergeReductionOrIterDomains(outer_reduction_tv, false); - if (rparams->multiple_reds_per_blk) { - outer_reduction_tv->split( - 0, NamedScalar::getParallelDim(rparams->block_dim_iter_dom)); - outer_reduction_tv->split( - 0, NamedScalar::getParallelDim(rparams->grid_dim_iter_dom), false); - } else { - outer_reduction_tv->split(0, rparams->lparams.gdimy()); - } - - if (rparams->multiple_reds_per_blk) { - outer_reduction_tv->rFactor({1}); - } - TensorView* partialResult = rparams->multiple_reds_per_blk - ? outer_reduction_tv->rFactor({1}) - : outer_reduction_tv->rFactor({0}); - partialResult->cacheBefore(); - partialResult->setMemoryType(MemoryType::Global); - TensorView* partialResultReload = partialResult->cacheAfter(); - - boundaryNodesSet.insert(partialResultReload); - cached_gmem.emplace_back(partialResult); - cached_gmem_reload.emplace_back(partialResultReload); - - if (rparams->multiple_reds_per_blk) { - if (rparams->tidx_for_outer_reduction) { - outer_reduction_tv->split( - 0, NamedScalar::getParallelDim(ParallelType::TIDx)); - outer_reduction_tv->axis(1)->parallelize(ParallelType::TIDx); - // to use warp reduction - if (rparams->pad_outer_reduction_to_warp) { - outer_reduction_tv->axis(1)->padToMultipleOfWarp(); - } - } else { - outer_reduction_tv->split( - 0, NamedScalar::getParallelDim(ParallelType::TIDy)); - outer_reduction_tv->axis(1)->parallelize(ParallelType::TIDy); - } - // iteration domain - int axisID = -1; - if (rparams->vectorization_factor_outer > 1) { - outer_reduction_tv->split(axisID, rparams->vectorization_factor_outer); - outer_reduction_tv->axis(axisID--)->parallelize( - ParallelType::Vectorize); - } - - if (rparams->tidx_for_outer_reduction) { - outer_reduction_tv->split( - axisID, NamedScalar::getParallelDim(ParallelType::TIDy)); - outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::TIDy); - } else { - outer_reduction_tv->split( - axisID, NamedScalar::getParallelDim(ParallelType::TIDx)); - outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::TIDx); - } - if (rparams->combined_split_grid_inner_dim) { - outer_reduction_tv->split( - axisID, NamedScalar::getParallelDim(ParallelType::BIDy)); - } - outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::BIDy); - - } else { - // reduction domain - outer_reduction_tv->split(0, rparams->lparams.bdimy()); - outer_reduction_tv->axis(1)->parallelize(ParallelType::TIDy); - - // iteration domain - int axisID = -1; - if (rparams->vectorization_factor_outer > 1) { - outer_reduction_tv->split(axisID, rparams->vectorization_factor_outer); - outer_reduction_tv->axis(axisID--)->parallelize( - ParallelType::Vectorize); - } - - if (rparams->lparams.bdimx() > 1) { - outer_reduction_tv->split(axisID, rparams->lparams.bdimx()); - outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::TIDx); - } - - if (rparams->combined_split_grid_inner_dim) { - outer_reduction_tv->split( - axisID, NamedScalar::getParallelDim(ParallelType::BIDy)); - } - - outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::BIDy); - } - auto outer_reference_tv = - reduction_scheduler_utils::sortAndRFactor(outer_reduction_tv); - outer_reference_tvs.emplace_back(outer_reference_tv); - } -} - -// fusion is the input IR that will be modified by this function -void scheduleInnerOuterPersistentKernel( - Fusion* fusion, - const ReductionParams* rparams) { - FusionGuard fg(fusion); - - // Grab the reduction, input, and output tensor views. dummy_outputs are - // helper tensors for persistent buffer projection. - std::vector dummy_outputs, cached_inputs, reduction_tvs, - smem_consumers; - std::vector> cached_outputs; - normalization_scheduler_utils::beforeSchedule( - fusion, - rparams, - dummy_outputs, - cached_inputs, - reduction_tvs, - smem_consumers, - cached_outputs); - - // split reduction_tvs into inner and outer reduction_tvs - std::vector inner_reduction_tvs, outer_reduction_tvs; - for (auto tv : reduction_tvs) { - if (scheduler_utils::isFastestDimReduction(tv)) { - inner_reduction_tvs.emplace_back(tv); - } else { - outer_reduction_tvs.emplace_back(tv); - } - } - NVF_ERROR( - !inner_reduction_tvs.empty(), - "schedulePersistentKernelInnerOuter is called but no inner reduction is found."); - NVF_ERROR( - !outer_reduction_tvs.empty(), - "schedulePersistentKernelInnerOuter is called but no outer reduction is found."); - - // schedule inner reduction, only schedule the first inner reduction tv, - // then will be propagated to other inner reduction tvs. - TensorView* inner_reference_tv = - normalization_scheduler_utils::scheduleReductionGeneral( - fusion, - rparams, - inner_reduction_tvs, - InnerOuterPersistentKernelScheduler::schedulerType()); - - // schedule outer reduction, schedule all the outer reduction tvs since we - // need to store the intermediate results. - std::vector cached_gmem; - std::vector cached_gmem_reload; - std::vector outer_reference_tvs; - std::unordered_set boundaryNodesSet; - scheduleReductionCombinedOuter( - fusion, - rparams, - outer_reduction_tvs, - cached_gmem, - cached_gmem_reload, - outer_reference_tvs, - boundaryNodesSet); - - // Propagate inner reduction and outer reductions - for (auto output : dummy_outputs) { - fusion->addOutput(output); - } - - const bool is_unroll_or_vectorization = rparams->isUnrolled(); - const bool is_vectorize = - rparams->vectorize_inner_reduction || rparams->vectorize_iter_dom; - const bool is_outer_grid_persistence = rparams->persistent_kernel && - rparams->cross_grid_inner_reduction && !rparams->fastest_dim; - - // Propagate inner reduction. There is a cutoff at boundaryNodesSet, so this - // propagation will not propagate to the final outer reduction. - reduction_scheduler_utils::propagateTransformation( - inner_reference_tv, boundaryNodesSet); - reduction_scheduler_utils::propagateRFactor( - inner_reference_tv, inner_reduction_tvs[0], inner_reduction_tvs); - - // Don't allow parallelization propagation goes through boundaryNodesSet - const auto& selected_tvs_inner = - scheduler_utils::getAllTvsFrom(inner_reduction_tvs, boundaryNodesSet); - const auto& unroll_vectorizable_cached_tvs = - reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize( - inner_reference_tv, is_vectorize, cached_inputs, cached_outputs); - reduction_scheduler_utils::propagateParallelization( - inner_reduction_tvs[0], - inner_reference_tv, - is_unroll_or_vectorization, - is_outer_grid_persistence, - inner_reduction_tvs, - unroll_vectorizable_cached_tvs, - {selected_tvs_inner.begin(), selected_tvs_inner.end()}); - - // Propagate outer reduction. Each outer reduction is connected with its - // cached_gmem and output, since we added all the cached_gmem to the - // boundaryNodesSet, the transformation from one outer reduction can't - // propagate to other outer reductions due to the cutoff at - // boundaryNodesSet. Thus, we need a loop to initiate the propagation from - // each outer reduction. Don't allow parallelization propagation goes - // through cached_gmem, see issue 246. - for (long unsigned int i = 0; i < outer_reference_tvs.size(); i++) { - const auto& selected_tvs_outer = scheduler_utils::getAllTvsFrom( - {outer_reduction_tvs[i]}, {cached_gmem[i]}); - reduction_scheduler_utils::propagateTransformation( - outer_reference_tvs[i], boundaryNodesSet); - const auto& unroll_vectorizable_cached_tvs = - reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize( - outer_reference_tvs[i], - is_vectorize, - cached_inputs, - cached_outputs); - reduction_scheduler_utils::propagateParallelization( - outer_reduction_tvs[i], - outer_reference_tvs[i], - is_unroll_or_vectorization, - is_outer_grid_persistence, - outer_reduction_tvs, - unroll_vectorizable_cached_tvs, - {selected_tvs_outer.begin(), selected_tvs_outer.end()}); - } - - // special vectorization of temp gmem, vectorization_factor_tmp_gmem_write - // is guaranteed to be smaller or equal to input vectorization factor. - if (rparams->vectorization_factor_tmp_gmem_write > 1) { - for (auto tv : cached_gmem) { - NVF_ERROR( - rparams->vectorization_factor_tmp_gmem_write <= - rparams->unroll_factor_inner_reduction, - "vectorization factor of temp gmem write should be smaller than that of inner reduction.") - if (rparams->vectorization_factor_tmp_gmem_write < - rparams->unroll_factor_inner_reduction) { - tv->split(-1, rparams->vectorization_factor_tmp_gmem_write); - } - tv->axis(-1)->parallelize(ParallelType::Vectorize); - } - } - // vectorization propagate through propagateParallelization only works for - // input and output tensors. propagate vectorization to cached_gmem_reload - // directly from output tv using parallelizeAllLike. must propagate - // seperaely for different tvs as outer reductions are transformed - // seperately. - if (rparams->vectorization_factor_outer > 1) { - for (auto tv : cached_gmem_reload) { - auto output_tvs = ir_utils::outputTvsOf(tv); - NVF_ERROR( - !output_tvs.empty(), - "cached_gmem_reload should have at least one output tensor.") - scheduler_utils::parallelizeAllLike( - output_tvs[0], - -1, - {cached_gmem_reload.begin(), cached_gmem_reload.end()}, - {ParallelType::Vectorize}); - } - } - - // Needs special handling of vectorized loading from shared memory due to - // potential different data types of inputs and shared memory tensor. - if (is_vectorize) { - reduction_scheduler_utils::sharedMemoryConsumerVectorization( - smem_consumers, rparams->unroll_factor_inner_reduction); - } - - // Remove dummy outputs as they can inadvertently affect CA positions - for (auto output : dummy_outputs) { - fusion->removeOutput(output); - } - inlineMost(); -} - -void scheduleTmaWarpSpecializedOuter( - Fusion* fusion, - const ReductionParams* rparams, - const std::vector& outer_reduction_tvs, - std::vector& cached_gmem, - std::vector& cached_gmem_reload, - std::vector& outer_reference_tvs, - std::unordered_set& boundaryNodesSet) { - auto mergeReductionOrIterDomains = [](TensorView* tv, bool mergeReduction) { - int prev_i = -1; - for (int i = static_cast(tv->nDims()) - 1; i >= 0; i--) { - if (mergeReduction == tv->axis(i)->isReduction()) { - if (prev_i == -1) { - prev_i = i; - } else { - tv->merge(i, prev_i); - prev_i = i; - } - } - } - }; - for (auto& outer_reduction_tv : outer_reduction_tvs) { - // Similar to the inner reduction, we need to reorder the outer reduction tv - // when there are view operations. - if (!ir_utils::getViewOps(fusion).empty()) { - // Reorder reference_tv after propagating the view operation. This will - // reorder for better merging. - outer_reduction_tv->reorder( - scheduler_utils::domainReorderAsLogicalMap(outer_reduction_tv)); - } - - // merge tensorview to [reduction, iteraiton] domains - mergeReductionOrIterDomains(outer_reduction_tv, true); - mergeReductionOrIterDomains(outer_reduction_tv, false); - - // First-stage of outer reduction - outer_reduction_tv->split(0, rparams->lparams.gdimy()); - - TensorView* partialResult = outer_reduction_tv->rFactor({0}); - partialResult->cacheBefore(); - partialResult->setMemoryType(MemoryType::Global); - TensorView* partialResultReload = partialResult->cacheAfter(); - - boundaryNodesSet.insert(partialResultReload); - cached_gmem.emplace_back(partialResult); - cached_gmem_reload.emplace_back(partialResultReload); - - // Second-stage of outer reduction - // reduction domain, [I1/TIDy, TIDy] - outer_reduction_tv->split(0, rparams->lparams.bdimy()); - outer_reduction_tv->axis(1)->parallelize(ParallelType::TIDy); - // iteration domain, [BIDy, TIDx, Vect] - int axisID = -1; - if (rparams->vectorization_factor_outer > 1) { - outer_reduction_tv->split(axisID, rparams->vectorization_factor_outer); - outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::Vectorize); - } - - if (rparams->lparams.bdimx() > 1) { - outer_reduction_tv->split(axisID, rparams->lparams.bdimx()); - outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::TIDx); - } - - if (rparams->combined_split_grid_inner_dim) { - outer_reduction_tv->split( - axisID, NamedScalar::getParallelDim(ParallelType::BIDy)); - } - - outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::BIDy); - - auto outer_reference_tv = - reduction_scheduler_utils::sortAndRFactor(outer_reduction_tv); - outer_reference_tvs.emplace_back(outer_reference_tv); - } -} - -void scheduleTmaWarpSpecializedInnerOuter( - Fusion* fusion, - const ReductionParams* rparams) { - FusionGuard fg(fusion); - - // Grab the reduction, input, and output tensor views. dummy_outputs are - // helper tensors for persistent buffer projection. - std::vector dummy_outputs, cached_inputs, reduction_tvs, - smem_consumers; - std::vector> cached_outputs; - normalization_scheduler_utils::beforeSchedule( - fusion, - rparams, - dummy_outputs, - cached_inputs, - reduction_tvs, - smem_consumers, - cached_outputs); - - // split reduction_tvs into inner and outer reduction_tvs - std::vector inner_reduction_tvs, outer_reduction_tvs; - for (auto tv : reduction_tvs) { - if (scheduler_utils::isFastestDimReduction(tv)) { - inner_reduction_tvs.emplace_back(tv); - } else { - outer_reduction_tvs.emplace_back(tv); - } - } - NVF_ERROR( - !inner_reduction_tvs.empty(), - "schedulePersistentKernelInnerOuter is called but no inner reduction is found."); - NVF_ERROR( - !outer_reduction_tvs.empty(), - "schedulePersistentKernelInnerOuter is called but no outer reduction is found."); - - // schedule inner reduction, only schedule the first inner reduction tv, - // then will be propagated to other inner reduction tvs. - TensorView* inner_reference_tv = - normalization_scheduler_utils::scheduleReductionGeneral( - fusion, - rparams, - inner_reduction_tvs, - InnerOuterPersistentKernelScheduler::schedulerType()); - - // schedule outer reduction, schedule all the outer reduction tvs since we - // need to store the intermediate results. - std::vector cached_gmem; - std::vector cached_gmem_reload; - std::vector outer_reference_tvs; - std::unordered_set boundaryNodesSet; - scheduleTmaWarpSpecializedOuter( - fusion, - rparams, - outer_reduction_tvs, - cached_gmem, - cached_gmem_reload, - outer_reference_tvs, - boundaryNodesSet); - - // Propagate inner reduction and outer reductions - for (auto output : dummy_outputs) { - fusion->addOutput(output); - } - - const bool is_unroll_or_vectorization = rparams->isUnrolled(); - const bool is_vectorize = - rparams->vectorize_inner_reduction || rparams->vectorize_iter_dom; - const bool is_outer_grid_persistence = rparams->persistent_kernel && - rparams->cross_grid_inner_reduction && !rparams->fastest_dim; - - // Propagate inner reduction. There is a cutoff at boundaryNodesSet, so this - // propagation will not propagate to the final outer reduction. - reduction_scheduler_utils::propagateTransformation( - inner_reference_tv, boundaryNodesSet); - reduction_scheduler_utils::propagateRFactor( - inner_reference_tv, inner_reduction_tvs[0], inner_reduction_tvs); - - // Don't allow parallelization propagation goes through boundaryNodesSet - const auto& selected_tvs_inner = - scheduler_utils::getAllTvsFrom(inner_reduction_tvs, boundaryNodesSet); - const auto& unroll_vectorizable_cached_tvs = - reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize( - inner_reference_tv, is_vectorize, cached_inputs, cached_outputs); - reduction_scheduler_utils::propagateParallelization( - inner_reduction_tvs[0], - inner_reference_tv, - is_unroll_or_vectorization, - is_outer_grid_persistence, - inner_reduction_tvs, - unroll_vectorizable_cached_tvs, - {selected_tvs_inner.begin(), selected_tvs_inner.end()}); - - // Propagate outer reduction. Each outer reduction is connected with its - // cached_gmem and output, since we added all the cached_gmem to the - // boundaryNodesSet, the transformation from one outer reduction can't - // propagate to other outer reductions due to the cutoff at - // boundaryNodesSet. Thus, we need a loop to initiate the propagation from - // each outer reduction. Don't allow parallelization propagation goes - // through cached_gmem, see issue 246. - for (long unsigned int i = 0; i < outer_reference_tvs.size(); i++) { - const auto& selected_tvs_outer = scheduler_utils::getAllTvsFrom( - {outer_reduction_tvs[i]}, {cached_gmem[i]}); - reduction_scheduler_utils::propagateTransformation( - outer_reference_tvs[i], boundaryNodesSet); - const auto& unroll_vectorizable_cached_tvs = - reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize( - outer_reference_tvs[i], - is_vectorize, - cached_inputs, - cached_outputs); - reduction_scheduler_utils::propagateParallelization( - outer_reduction_tvs[i], - outer_reference_tvs[i], - is_unroll_or_vectorization, - is_outer_grid_persistence, - outer_reduction_tvs, - unroll_vectorizable_cached_tvs, - {selected_tvs_outer.begin(), selected_tvs_outer.end()}); - } - - // special vectorization of temp gmem, vectorization_factor_tmp_gmem_write - // is guaranteed to be smaller or equal to input vectorization factor. - if (rparams->vectorization_factor_tmp_gmem_write > 1) { - for (auto tv : cached_gmem) { - NVF_ERROR( - rparams->vectorization_factor_tmp_gmem_write <= - rparams->unroll_factor_inner_reduction, - "vectorization factor of temp gmem write should be smaller than that of inner reduction.") - if (rparams->vectorization_factor_tmp_gmem_write < - rparams->unroll_factor_inner_reduction) { - tv->split(-1, rparams->vectorization_factor_tmp_gmem_write); - } - tv->axis(-1)->parallelize(ParallelType::Vectorize); - } - } - // vectorization propagate through propagateParallelization only works for - // input and output tensors. propagate vectorization to cached_gmem_reload - // directly from output tv using parallelizeAllLike. must propagate - // seperaely for different tvs as outer reductions are transformed - // seperately. - if (rparams->vectorization_factor_outer > 1) { - for (auto tv : cached_gmem_reload) { - auto output_tvs = ir_utils::outputTvsOf(tv); - NVF_ERROR( - !output_tvs.empty(), - "cached_gmem_reload should have at least one output tensor.") - scheduler_utils::parallelizeAllLike( - output_tvs[0], - -1, - {cached_gmem_reload.begin(), cached_gmem_reload.end()}, - {ParallelType::Vectorize}); - } - } - - // Needs special handling of vectorized loading from shared memory due to - // potential different data types of inputs and shared memory tensor. - if (is_vectorize) { - reduction_scheduler_utils::sharedMemoryConsumerVectorization( - smem_consumers, rparams->unroll_factor_inner_reduction); - } - - // Remove dummy outputs as they can inadvertently affect CA positions - for (auto output : dummy_outputs) { - fusion->removeOutput(output); - } - inlineMost(); -} - } // namespace bool InnerOuterPersistentKernelScheduler::canScheduleCompileTime( @@ -1938,14 +354,15 @@ bool InnerOuterPersistentKernelScheduler::canScheduleRunTime( scheduler_hyperparameters_entry.get(); // check if there is enough register and shared memory for persistence - const auto buffer_params = getPersistentBufferStorageParams( - fusion, - runtime_info, - data_cache, - reduction_tvs, - hp.vectorize_factor, - hp.threads_per_block_min, - hp.threads_per_block_max); + const auto buffer_params = + inner_outer_utils::getPersistentBufferStorageParams( + fusion, + runtime_info, + data_cache, + reduction_tvs, + hp.vectorize_factor, + hp.threads_per_block_min, + hp.threads_per_block_max); const int64_t device_multiprocessor_count = (int64_t)at::cuda::getCurrentDeviceProperties()->multiProcessorCount; @@ -2014,9 +431,9 @@ void InnerOuterPersistentKernelScheduler::schedule( "Incorrect parameters sent to InnerOuterPersistentKernelScheduler::schedule", params); if (rparams->tma_warp_specialized) { - scheduleTmaWarpSpecializedInnerOuter(fusion, rparams); + inner_outer_tma_warp_specialized::scheduleFusion(fusion, rparams); } else { - scheduleInnerOuterPersistentKernel(fusion, rparams); + inner_outer_multi_wave::scheduleFusion(fusion, rparams); } } } // namespace nvfuser diff --git a/csrc/scheduler/normalization_inner_outer_multi_wave.cpp b/csrc/scheduler/normalization_inner_outer_multi_wave.cpp new file mode 100644 index 00000000000..6ab87e871bf --- /dev/null +++ b/csrc/scheduler/normalization_inner_outer_multi_wave.cpp @@ -0,0 +1,717 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include +#include + +#include + +namespace nvfuser { +namespace inner_outer_multi_wave { +// The innerOuterPersistentHeuristic is tuned for layer_norm backward on A100 +// ======= Method if hidden_size > 1024 ======= +// (1) Inner reduction is one reduction per block. Reduction domain is +// parallelized by TIDx and TIDy, Iteration domain is parallelized by BIDy. +// (2) Outer reduction is done in two-steps. The first step is partial +// reduction, reduction domain is parallelized by BIDy, iteration domain is +// parallelized by TIDx and TIDy. The partial results are written to gmem +// followed by a grid sync. The second step is block reduction, the reduction +// domain is parallelized by TIDy, the iteration domain is parallelized by TIDx +// and BIDy. +// ======= Method if hidden_size <= 1024 ======= +// (1) Inner reduction is multi-reductions per blocks. Reduction domain is +// parallelized by TIDx, Iteration domain is parallelized by BIDy and TIDy. +// (2) Outer reduction is same to cases where hidden_size > 1024 except the +// second step where in this case, the reduction domain is parallelized by TIDx +// and the iteration domain is parallelized by TIDy and BIDy. This switch +// between TIDx and TIDy is because: +// (a) We can do warp reduction with TIDx +// (b) TIDx*BIDy is usually much larger than hidden_size, e.g. 128*216 = 1024*27 +// this means without switch only 1/27 of the threads is used. +void getHeuristics( + ReductionParams* rparams, + const int64_t outer_dim_numel, + const int64_t inner_dim_numel, + const int64_t regs_buffer_size, + const int64_t smem_buffer_size, + const int64_t smem_overhead, + const size_t tmp_gmem_dtype_size, + const size_t vectorize_factor, + const int64_t hp_threads_per_block_min, + const int64_t hp_threads_per_block_max, + const bool project_to_input, + const PrimDataType index_type) { + rparams->project_persistent_buffers = project_to_input; + rparams->cparams.index_type = index_type; + const auto dev_prop = at::cuda::getCurrentDeviceProperties(); + const int64_t device_multiprocessor_count = + (int64_t)dev_prop->multiProcessorCount; + // Parameters for inner reduction: + // Reduction dim: inner_vect, inner_batch, bdimx and bdimy + // Iteration dim: gdimy + + // Parameters for outer reduction: + // Reduction dim: bdimy + // Iteration dim: vectorization_factor_outer, bdimx, gdimy + struct InnerOuterParams { + int64_t inner_vect = -1; + int64_t inner_batch = -1; + int64_t bdimx = -1; + int64_t bdimy = -1; + int64_t bdimz = -1; + int64_t gdimy = -1; + int64_t tmp_gmem_write_vect = -1; + int64_t vectorization_factor_outer = -1; + int64_t threads_per_block = -1; + // derived metrics for sorting + int64_t warps_per_sm = -1; + int64_t required_register_per_thread = -1; + int64_t available_register_per_thread = -1; + + void verify() { + NVF_ERROR(inner_vect != -1, "inner_vect is not set."); + NVF_ERROR(inner_batch != -1, "inner_batch is not set."); + NVF_ERROR(bdimx != -1, "bdimx is not set."); + NVF_ERROR(bdimy != -1, "bdimy is not set."); + NVF_ERROR(gdimy != -1, "gdimy is not set."); + NVF_ERROR(tmp_gmem_write_vect != -1, "tmp_gmem_write_vect is not set."); + NVF_ERROR( + vectorization_factor_outer != -1, + "vectorization_factor_outer is not set."); + } + std::string toString() const { + std::stringstream ss; + ss << "inner_vect: " << inner_vect << ", inner_batch: " << inner_batch + << ", bdimx: " << bdimx << ", bdimy: " << bdimy << ", bdimz: " << bdimz + << ", gdimy: " << gdimy + << ", tmp_gmem_write_vect: " << tmp_gmem_write_vect + << ", vectorization_factor_outer: " << vectorization_factor_outer + << ", threads_per_block: " << threads_per_block + << ", warps_per_sm: " << warps_per_sm + << ", required_register_per_thread: " << required_register_per_thread + << ", available_register_per_thread: " + << available_register_per_thread; + return ss.str(); + } + }; + + // Set a minimum workload for each thread to take advantage of low + // intra-threads communication cost. + // Tuned for layer_norm backward on A100, still works fine on H100. + auto getMinimumBatch = [&]() -> int64_t { + if (inner_dim_numel >= 3072l) { + if (outer_dim_numel <= 2048l && inner_dim_numel == 3072l) { + return 3l; + } else { + return 4l; + } + } else if (inner_dim_numel >= 2048l) { + return 2l; + } + return 1l; + }; + + // Estimate register usage per thread based on buffer size. + // Assuming a constant register overhead for non-buffer related usage, + // and all the register buffers are stored in registers. + auto getEstimatedRegisterUsage = [&](int64_t batch_mul_vect) { + int64_t persistent_buffer_size = + regs_buffer_size / inner_dim_numel * batch_mul_vect; + int64_t estimated_register_count = + persistent_buffer_size / scheduler_utils::bytes_per_register + + scheduler_utils::register_overhead; + return std::min( + estimated_register_count, scheduler_utils::max_registers_per_thread); + }; + + // Estimate max blocks per sm based on register and shared memory usage. + auto getBlocksPerSM = [&](const int64_t threads_per_sm, + const int64_t threads_per_block, + const int64_t warp_size) { + // check register limitation on blocks per sm + constexpr int64_t warp_allocation_granularity = 4; + const int64_t allocated_warps_per_block = + ceilDiv( + ceilDiv(threads_per_block, warp_size), + warp_allocation_granularity) * + warp_allocation_granularity; + int64_t max_blocks_per_sm_regs = scheduler_utils::safeDiv( + threads_per_sm / warp_size, allocated_warps_per_block); + // check shared memory limitation on blocks per sm + int64_t max_blocks_per_sm_smem = + (int64_t)dev_prop->sharedMemPerMultiprocessor / + (smem_overhead + smem_buffer_size); + return std::min(max_blocks_per_sm_regs, max_blocks_per_sm_smem); + }; + + // In the inner reduction part of the kernel, gdimy is used to parallelize the + // outer dimension. The kernel is a cooperative kernel, so the number of + // blocks should be as large as possible to achieve a high occupancy unless + // outer dim is too small which may lead large workload for the final outer + // reduction. So, gdimy is drvied from the number of blocks per sm and limited + // to ensure at least 8 rows per block. + // TODO: re-evaluate this 8 rows per block requirement. + auto getGdimy = [&](int64_t inner_vect, + int64_t threads_per_block, + int64_t inner_batch) { + int64_t reg_per_thread = + getEstimatedRegisterUsage(inner_vect * inner_batch); + int64_t threads_per_sm = getThreadsPerSMGivenRegPerThread(reg_per_thread); + int64_t blocks_per_sm = + getBlocksPerSM(threads_per_sm, threads_per_block, dev_prop->warpSize); + int64_t gdimy = blocks_per_sm * device_multiprocessor_count; + const int64_t outer_iter_min = 8; + const int64_t gdimy_max = scheduler_utils::roundUpToN( + ceilDiv(outer_dim_numel, outer_iter_min), device_multiprocessor_count); + while (gdimy > gdimy_max && blocks_per_sm > 1) { + blocks_per_sm -= 1; + gdimy = blocks_per_sm * device_multiprocessor_count; + } + return gdimy; + }; + + // The inner reduction part of the kernel also does a partial outer reduction + // and stores the partial results in tmp gmem and then reloaded to finish the + // outer reduciton. This function set the vectorization factor for write and + // and read of the partial outer reduction result. + // For write to tmp gmem, follows vectorization factor of inner reduction + // but don't exceed 16 bytes. + // For read from tmp gmem, since the paralelization is changed, a different + // vectorization factor is used to optimize the + // number of reaductions per thread. + auto getOuterReductionBufferVectFactor = [&](int64_t inner_vect) { + constexpr int64_t max_gmem_vect_access_bytes = 16; + const int64_t max_tmp_gmem_vect_factor = std::min( + max_gmem_vect_access_bytes / (int64_t)tmp_gmem_dtype_size, inner_vect); + int64_t tmp_gmem_write_vect = max_tmp_gmem_vect_factor; + const int64_t workload_per_thread = inner_dim_numel >= 4096 ? 4l : 2l; + int64_t vectorization_factor_outer = + std::min(workload_per_thread, max_tmp_gmem_vect_factor); + return std::make_pair(tmp_gmem_write_vect, vectorization_factor_outer); + }; + + // In the outer reduction part of the kernel, inner and outer dims are + // parallelized as: + // --- inner dim: vect, bdimx, gdimy ---- + // --- outer dim: bdimy ----------------- + // This function splits the threads_per_block into bdimx and bdimy using: + // bdimx = ceilDiv(inner_dim_numel / vect, gdimy) + // bdimy = threads_per_block / bdimx + auto getBdimxBdimy = [&](int64_t threads_per_block, + int64_t vectorization_factor_outer, + int64_t gdimy) { + // For widely used hidden sizes, threads_per_block has factor of 8, roundup + // to increase the probability of bdimx * bdimy == threads_per_block. + int64_t bdimx = scheduler_utils::roundUpPow2Or8( + ceilDiv(inner_dim_numel / vectorization_factor_outer, gdimy)); + // if still not divisible, e.g. threads_per_block = 256, bdimx = 40. + // increase bdimx to make it divisible. Under worst case, bdimx equals to + // threads_per_block. + while (threads_per_block % bdimx) { + bdimx = std::min(bdimx + 8, threads_per_block); + } + // Set OuterParams Reduction dim: bdimy. + int64_t bdimy = threads_per_block / bdimx; + NVF_ERROR( + bdimy * bdimx == threads_per_block, + " threads_per_block must be divisible by bdimx and bdimy."); + return std::make_pair(bdimx, bdimy); + }; + + // Get the heuristics given vectorization factor and threads per block + auto getHeuristicsGivenVectThreads = [&](int64_t vect_factor, + int64_t threads_per_block) { + InnerOuterParams iop; + // (1) inner reduction + // Reduction dim: inner_batch, threads_per_block, vect_factor + // Iteration dim: gdimy + iop.inner_vect = vect_factor; + iop.threads_per_block = threads_per_block; + iop.inner_batch = + ceilDiv(inner_dim_numel / iop.inner_vect, iop.threads_per_block); + iop.gdimy = + getGdimy(iop.inner_vect, iop.threads_per_block, iop.inner_batch); + // (2) outer reduction + // Iteration dim: gdimy, bdimx, vectorization_factor_outer + // Reduction dim: bdimy + std::tie(iop.tmp_gmem_write_vect, iop.vectorization_factor_outer) = + getOuterReductionBufferVectFactor(iop.inner_vect); + auto [bdimx, bdimy] = getBdimxBdimy( + threads_per_block, iop.vectorization_factor_outer, iop.gdimy); + iop.bdimx = bdimx; + iop.bdimy = bdimy; + // (3) Derived metrics warps_per_sm and register usage for sorting + iop.warps_per_sm = ceilDiv(iop.threads_per_block, dev_prop->warpSize) * + iop.gdimy / device_multiprocessor_count; + iop.available_register_per_thread = + getRegPerThreadGivenThreadsPerSM(dev_prop->warpSize * iop.warps_per_sm); + iop.required_register_per_thread = + getEstimatedRegisterUsage(iop.inner_vect * iop.inner_batch); + return iop; + }; + + // Use the maximum vectorization factor + const int64_t vect_factor = (int64_t)vectorize_factor; + + // Set a reasonable range for threads per block based on the number of + // elements in the inner dimension after vectorization. + // Start from 128 or a smaller number if inner dim is small. + const int64_t after_vect = inner_dim_numel / vect_factor; + const int64_t batch_min = getMinimumBatch(); + int64_t threads_per_block_min = hp_threads_per_block_min; + threads_per_block_min = std::min(threads_per_block_min, after_vect); + threads_per_block_min = scheduler_utils::roundUpPow2(threads_per_block_min); + + // star max threads per block from min threads per block + int64_t threads_per_block_max = threads_per_block_min; + // increase to cover the whole inner dim + threads_per_block_max = + std::max(threads_per_block_max, ceilDiv(after_vect, batch_min)); + // round up to power of 2 + threads_per_block_max = scheduler_utils::roundUpPow2(threads_per_block_max); + // don't go beyond the maximum threads per block + threads_per_block_max = + std::min(threads_per_block_max, hp_threads_per_block_max); + + // Store all the possible heuristics based on different threads per block. + // Vectorizaton is fixed at the maximum value. + std::vector iop_candidates; + for (auto threads_per_block = threads_per_block_max; + threads_per_block >= threads_per_block_min; + threads_per_block /= 2) { + iop_candidates.emplace_back( + getHeuristicsGivenVectThreads(vect_factor, threads_per_block)); + } + + // Sort the heuristics based on the register usage and occupancy. + std::stable_sort( + iop_candidates.begin(), + iop_candidates.end(), + [](const InnerOuterParams& a, const InnerOuterParams& b) { + // If a thread can use more registers than required, there is a high + // chance that it can avoid register spilling and compiler can optimize + // for better instruction level parallelism. + int64_t extra_regs_a = + a.available_register_per_thread - a.required_register_per_thread; + int64_t extra_regs_b = + b.available_register_per_thread - b.required_register_per_thread; + if (extra_regs_a > 0 && extra_regs_b < 0) { + return true; + } else if (extra_regs_a < 0 && extra_regs_b > 0) { + return false; + } + // High occupancy provides better threads level parallelism. + // 25% is sufficient since ILP is high due to persistent batch sizes + // which is equivalent to unrolling inner dim. + if (a.warps_per_sm != b.warps_per_sm && + (a.warps_per_sm < 16 || b.warps_per_sm < 16)) { + return a.warps_per_sm > b.warps_per_sm; + } + // Tie breaker, smaller threads_per_block to reduce communication + // overhead + return a.threads_per_block < b.threads_per_block; + }); + + // Pick the best heuristic + auto iop = iop_candidates.front(); + + // Special case, when inner_dim_numel <= 1024, bdimx is usually small + // after divide by inner_vect and inner_batch. In this case, bdimy is used to + // parallelize outer_dim instead of inner_dim. This pattern is named multi + // reductions per block (mrpb). + if (inner_dim_numel <= 1024) { + rparams->multiple_reds_per_blk = true; + rparams->tidx_for_outer_reduction = true; + + // Step-1, InnerParams, Reduction dim: inner_vect(reuse), + // inner_batch(reuse), bdimx + iop.bdimx = ceilDiv(inner_dim_numel, iop.inner_vect * iop.inner_batch); + + // Step-2, InnerParams, Iteration dim: gdimy, bdimy (in next step) + iop.gdimy = getGdimy(iop.inner_vect, iop.bdimx, iop.inner_batch); + + // Step-3, OuterParams, Iteration dim: vectorization_factor_outer(reuse), + // bdimy, gdimy (in previous step). + // WAR for https://github.com/NVIDIA/Fuser/issues/3428 + iop.bdimy = 1; + + // Step-4, OuterParams, Reduction dim: bdimx (already done) + iop.warps_per_sm = ceilDiv(iop.bdimx * iop.bdimy, dev_prop->warpSize) * + iop.gdimy / device_multiprocessor_count; + iop.available_register_per_thread = + getRegPerThreadGivenThreadsPerSM(dev_prop->warpSize * iop.warps_per_sm); + + if (iop.bdimx % dev_prop->warpSize == 0) { + rparams->pad_inner_reduction_to_warp = true; + rparams->pad_outer_reduction_to_warp = true; + } + rparams->block_dim_iter_dom = ParallelType::TIDy; + rparams->combined_split_grid_inner_dim = + iop.vectorization_factor_outer * iop.bdimy * iop.gdimy < + inner_dim_numel; + } else { + rparams->block_dim_inner_reduction_extra = ParallelType::TIDy; + rparams->combined_split_grid_inner_dim = + iop.vectorization_factor_outer * iop.bdimx * iop.gdimy < + inner_dim_numel; + rparams->static_bdimx = true; + rparams->static_bdimy = true; + iop.bdimz = ceilDiv( + ceilDiv( + ceilDiv(inner_dim_numel / iop.inner_vect, iop.bdimx), iop.bdimy), + iop.inner_batch); + NVF_ERROR(iop.bdimz == 1, "bdimz must be 1."); + } + + // check all the parameters in InnerOuterParams are set. + iop.verify(); + + rparams->persistent_kernel = true; + rparams->fastest_dim = true; + rparams->combined_inner_outer = true; + // tmp_gmem is the intermediate result of outer reduction, its dtype is float, + // so the maximum vectorization factor is 4. + rparams->vectorization_factor_outer = iop.vectorization_factor_outer; + rparams->vectorization_factor_tmp_gmem_write = iop.tmp_gmem_write_vect; + rparams->cparams.maxrregcount = iop.available_register_per_thread; + rparams->unroll_factor_inner_reduction = iop.inner_vect; + rparams->batches_per_block_inner_reduction = iop.inner_batch; + rparams->block_dim_inner_reduction = ParallelType::TIDx; + rparams->vectorize_inner_reduction = iop.inner_vect > 1; + rparams->split_grid_dim_iter_dom_outer = true; + rparams->grid_dim_iter_dom = ParallelType::BIDy; + + rparams->lparams = LaunchParams( + LaunchParams::UNINITIALIZED_VAL, + iop.gdimy, + LaunchParams::UNINITIALIZED_VAL, + iop.bdimx, + iop.bdimy, + LaunchParams::UNINITIALIZED_VAL); + + if (!rparams->smem_persistent_buffers.empty()) { + rparams->tag = + "InnerOuter Register and Shared Memory Persistent Heuristic.\n"; + } else { + rparams->tag = "InnerOuter Register Persistent Heuristic.\n"; + } + + if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { + debug() << "\n===== Combined InnerOuter Reduction Stats ========\n" + << "outer_dim_numel: " << outer_dim_numel << "\n" + << "inner_dim_numel: " << inner_dim_numel << "\n" + << "regs_buffer_size: " << regs_buffer_size << "\n" + << "smem_buffer_size: " << smem_buffer_size << "\n" + << "smem_overhead: " << smem_overhead << "\n" + << "vectorize_factor_input: " << iop.inner_vect << "\n" + << "vectorization_factor_tmp_gmem_write: " + << iop.tmp_gmem_write_vect << "\n" + << "vectorization_factor_outer: " << iop.vectorization_factor_outer + << "\n" + << "multiple_reds_per_blk: " << rparams->multiple_reds_per_blk + << "\n" + << "warps_per_sm: " << iop.warps_per_sm << "\n" + << "gdimy: " << iop.gdimy << "\n" + << "block(" << (iop.bdimx) << ", " << iop.bdimy << ", " << 1 << ")"; + debug() << rparams->toString() << std::endl; + } +} + +void scheduleOuterReduction( + Fusion* fusion, + const ReductionParams* rparams, + const std::vector& outer_reduction_tvs, + std::vector& cached_gmem, + std::vector& cached_gmem_reload, + std::vector& outer_reference_tvs, + std::unordered_set& boundaryNodesSet) { + auto mergeReductionOrIterDomains = [](TensorView* tv, bool mergeReduction) { + int prev_i = -1; + for (int i = static_cast(tv->nDims()) - 1; i >= 0; i--) { + if (mergeReduction == tv->axis(i)->isReduction()) { + if (prev_i == -1) { + prev_i = i; + } else { + tv->merge(i, prev_i); + prev_i = i; + } + } + } + }; + for (auto& outer_reduction_tv : outer_reduction_tvs) { + // Similar to the inner reduction, we need to reorder the outer reduction tv + // when there are view operations. + if (!ir_utils::getViewOps(fusion).empty()) { + // Reorder reference_tv after propagating the view operation. This will + // reorder for better merging. + outer_reduction_tv->reorder( + scheduler_utils::domainReorderAsLogicalMap(outer_reduction_tv)); + } + + // merge tensorview to [reduction, iteraiton] domains + mergeReductionOrIterDomains(outer_reduction_tv, true); + mergeReductionOrIterDomains(outer_reduction_tv, false); + if (rparams->multiple_reds_per_blk) { + outer_reduction_tv->split( + 0, NamedScalar::getParallelDim(rparams->block_dim_iter_dom)); + outer_reduction_tv->split( + 0, NamedScalar::getParallelDim(rparams->grid_dim_iter_dom), false); + } else { + outer_reduction_tv->split(0, rparams->lparams.gdimy()); + } + + if (rparams->multiple_reds_per_blk) { + outer_reduction_tv->rFactor({1}); + } + TensorView* partialResult = rparams->multiple_reds_per_blk + ? outer_reduction_tv->rFactor({1}) + : outer_reduction_tv->rFactor({0}); + partialResult->cacheBefore(); + partialResult->setMemoryType(MemoryType::Global); + TensorView* partialResultReload = partialResult->cacheAfter(); + + boundaryNodesSet.insert(partialResultReload); + cached_gmem.emplace_back(partialResult); + cached_gmem_reload.emplace_back(partialResultReload); + + if (rparams->multiple_reds_per_blk) { + if (rparams->tidx_for_outer_reduction) { + outer_reduction_tv->split( + 0, NamedScalar::getParallelDim(ParallelType::TIDx)); + outer_reduction_tv->axis(1)->parallelize(ParallelType::TIDx); + // to use warp reduction + if (rparams->pad_outer_reduction_to_warp) { + outer_reduction_tv->axis(1)->padToMultipleOfWarp(); + } + } else { + outer_reduction_tv->split( + 0, NamedScalar::getParallelDim(ParallelType::TIDy)); + outer_reduction_tv->axis(1)->parallelize(ParallelType::TIDy); + } + // iteration domain + int axisID = -1; + if (rparams->vectorization_factor_outer > 1) { + outer_reduction_tv->split(axisID, rparams->vectorization_factor_outer); + outer_reduction_tv->axis(axisID--)->parallelize( + ParallelType::Vectorize); + } + + if (rparams->tidx_for_outer_reduction) { + outer_reduction_tv->split( + axisID, NamedScalar::getParallelDim(ParallelType::TIDy)); + outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::TIDy); + } else { + outer_reduction_tv->split( + axisID, NamedScalar::getParallelDim(ParallelType::TIDx)); + outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::TIDx); + } + if (rparams->combined_split_grid_inner_dim) { + outer_reduction_tv->split( + axisID, NamedScalar::getParallelDim(ParallelType::BIDy)); + } + outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::BIDy); + + } else { + // reduction domain + outer_reduction_tv->split(0, rparams->lparams.bdimy()); + outer_reduction_tv->axis(1)->parallelize(ParallelType::TIDy); + + // iteration domain + int axisID = -1; + if (rparams->vectorization_factor_outer > 1) { + outer_reduction_tv->split(axisID, rparams->vectorization_factor_outer); + outer_reduction_tv->axis(axisID--)->parallelize( + ParallelType::Vectorize); + } + + if (rparams->lparams.bdimx() > 1) { + outer_reduction_tv->split(axisID, rparams->lparams.bdimx()); + outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::TIDx); + } + + if (rparams->combined_split_grid_inner_dim) { + outer_reduction_tv->split( + axisID, NamedScalar::getParallelDim(ParallelType::BIDy)); + } + + outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::BIDy); + } + auto outer_reference_tv = + reduction_scheduler_utils::sortAndRFactor(outer_reduction_tv); + outer_reference_tvs.emplace_back(outer_reference_tv); + } +} + +// fusion is the input IR that will be modified by this function +void scheduleFusion(Fusion* fusion, const ReductionParams* rparams) { + FusionGuard fg(fusion); + + // Grab the reduction, input, and output tensor views. dummy_outputs are + // helper tensors for persistent buffer projection. + std::vector dummy_outputs, cached_inputs, reduction_tvs, + smem_consumers; + std::vector> cached_outputs; + normalization_scheduler_utils::beforeSchedule( + fusion, + rparams, + dummy_outputs, + cached_inputs, + reduction_tvs, + smem_consumers, + cached_outputs); + + // split reduction_tvs into inner and outer reduction_tvs + std::vector inner_reduction_tvs, outer_reduction_tvs; + for (auto tv : reduction_tvs) { + if (scheduler_utils::isFastestDimReduction(tv)) { + inner_reduction_tvs.emplace_back(tv); + } else { + outer_reduction_tvs.emplace_back(tv); + } + } + NVF_ERROR( + !inner_reduction_tvs.empty(), + "schedulePersistentKernelInnerOuter is called but no inner reduction is found."); + NVF_ERROR( + !outer_reduction_tvs.empty(), + "schedulePersistentKernelInnerOuter is called but no outer reduction is found."); + + // schedule inner reduction, only schedule the first inner reduction tv, + // then will be propagated to other inner reduction tvs. + TensorView* inner_reference_tv = + normalization_scheduler_utils::scheduleReductionGeneral( + fusion, + rparams, + inner_reduction_tvs, + SchedulerType::InnerOuterPersistent); + + // schedule outer reduction, schedule all the outer reduction tvs since we + // need to store the intermediate results. + std::vector cached_gmem; + std::vector cached_gmem_reload; + std::vector outer_reference_tvs; + std::unordered_set boundaryNodesSet; + scheduleOuterReduction( + fusion, + rparams, + outer_reduction_tvs, + cached_gmem, + cached_gmem_reload, + outer_reference_tvs, + boundaryNodesSet); + + // Propagate inner reduction and outer reductions + for (auto output : dummy_outputs) { + fusion->addOutput(output); + } + + const bool is_unroll_or_vectorization = rparams->isUnrolled(); + const bool is_vectorize = + rparams->vectorize_inner_reduction || rparams->vectorize_iter_dom; + const bool is_outer_grid_persistence = rparams->persistent_kernel && + rparams->cross_grid_inner_reduction && !rparams->fastest_dim; + + // Propagate inner reduction. There is a cutoff at boundaryNodesSet, so this + // propagation will not propagate to the final outer reduction. + reduction_scheduler_utils::propagateTransformation( + inner_reference_tv, boundaryNodesSet); + reduction_scheduler_utils::propagateRFactor( + inner_reference_tv, inner_reduction_tvs[0], inner_reduction_tvs); + + // Don't allow parallelization propagation goes through boundaryNodesSet + const auto& selected_tvs_inner = + scheduler_utils::getAllTvsFrom(inner_reduction_tvs, boundaryNodesSet); + const auto& unroll_vectorizable_cached_tvs = + reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize( + inner_reference_tv, is_vectorize, cached_inputs, cached_outputs); + reduction_scheduler_utils::propagateParallelization( + inner_reduction_tvs[0], + inner_reference_tv, + is_unroll_or_vectorization, + is_outer_grid_persistence, + inner_reduction_tvs, + unroll_vectorizable_cached_tvs, + {selected_tvs_inner.begin(), selected_tvs_inner.end()}); + + // Propagate outer reduction. Each outer reduction is connected with its + // cached_gmem and output, since we added all the cached_gmem to the + // boundaryNodesSet, the transformation from one outer reduction can't + // propagate to other outer reductions due to the cutoff at + // boundaryNodesSet. Thus, we need a loop to initiate the propagation from + // each outer reduction. Don't allow parallelization propagation goes + // through cached_gmem, see issue 246. + for (long unsigned int i = 0; i < outer_reference_tvs.size(); i++) { + const auto& selected_tvs_outer = scheduler_utils::getAllTvsFrom( + {outer_reduction_tvs[i]}, {cached_gmem[i]}); + reduction_scheduler_utils::propagateTransformation( + outer_reference_tvs[i], boundaryNodesSet); + const auto& unroll_vectorizable_cached_tvs = + reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize( + outer_reference_tvs[i], + is_vectorize, + cached_inputs, + cached_outputs); + reduction_scheduler_utils::propagateParallelization( + outer_reduction_tvs[i], + outer_reference_tvs[i], + is_unroll_or_vectorization, + is_outer_grid_persistence, + outer_reduction_tvs, + unroll_vectorizable_cached_tvs, + {selected_tvs_outer.begin(), selected_tvs_outer.end()}); + } + + // special vectorization of temp gmem, vectorization_factor_tmp_gmem_write + // is guaranteed to be smaller or equal to input vectorization factor. + if (rparams->vectorization_factor_tmp_gmem_write > 1) { + for (auto tv : cached_gmem) { + NVF_ERROR( + rparams->vectorization_factor_tmp_gmem_write <= + rparams->unroll_factor_inner_reduction, + "vectorization factor of temp gmem write should be smaller than that of inner reduction.") + if (rparams->vectorization_factor_tmp_gmem_write < + rparams->unroll_factor_inner_reduction) { + tv->split(-1, rparams->vectorization_factor_tmp_gmem_write); + } + tv->axis(-1)->parallelize(ParallelType::Vectorize); + } + } + // vectorization propagate through propagateParallelization only works for + // input and output tensors. propagate vectorization to cached_gmem_reload + // directly from output tv using parallelizeAllLike. must propagate + // seperaely for different tvs as outer reductions are transformed + // seperately. + if (rparams->vectorization_factor_outer > 1) { + for (auto tv : cached_gmem_reload) { + auto output_tvs = ir_utils::outputTvsOf(tv); + NVF_ERROR( + !output_tvs.empty(), + "cached_gmem_reload should have at least one output tensor.") + scheduler_utils::parallelizeAllLike( + output_tvs[0], + -1, + {cached_gmem_reload.begin(), cached_gmem_reload.end()}, + {ParallelType::Vectorize}); + } + } + + // Needs special handling of vectorized loading from shared memory due to + // potential different data types of inputs and shared memory tensor. + if (is_vectorize) { + reduction_scheduler_utils::sharedMemoryConsumerVectorization( + smem_consumers, rparams->unroll_factor_inner_reduction); + } + + // Remove dummy outputs as they can inadvertently affect CA positions + for (auto output : dummy_outputs) { + fusion->removeOutput(output); + } + inlineMost(); +} +} // namespace inner_outer_multi_wave +} // namespace nvfuser diff --git a/csrc/scheduler/normalization_inner_outer_multi_wave.h b/csrc/scheduler/normalization_inner_outer_multi_wave.h new file mode 100644 index 00000000000..3b373189246 --- /dev/null +++ b/csrc/scheduler/normalization_inner_outer_multi_wave.h @@ -0,0 +1,31 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include + +namespace nvfuser { +namespace inner_outer_multi_wave { +void getHeuristics( + ReductionParams* rparams, + const int64_t outer_dim_numel, + const int64_t inner_dim_numel, + const int64_t regs_buffer_size, + const int64_t smem_buffer_size, + const int64_t smem_overhead, + const size_t tmp_gmem_dtype_size, + const size_t vectorize_factor, + const int64_t hp_threads_per_block_min, + const int64_t hp_threads_per_block_max, + const bool project_to_input, + const PrimDataType index_type); + +void scheduleFusion(Fusion* fusion, const ReductionParams* rparams); +} // namespace inner_outer_multi_wave +} // namespace nvfuser diff --git a/csrc/scheduler/normalization_inner_outer_tma_ws.cpp b/csrc/scheduler/normalization_inner_outer_tma_ws.cpp new file mode 100644 index 00000000000..0da6b734435 --- /dev/null +++ b/csrc/scheduler/normalization_inner_outer_tma_ws.cpp @@ -0,0 +1,647 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include +#include +#include + +#include +namespace nvfuser { +namespace inner_outer_tma_warp_specialized { +void getHeuristics( + ReductionParams* rparams, + const int64_t outer_dim_numel, + const int64_t inner_dim_numel, + const int64_t regs_buffer_size, + const int64_t smem_buffer_size, + const int64_t smem_overhead, + const size_t tmp_gmem_dtype_size, + const size_t vectorize_factor, + const int64_t hp_threads_per_block_min, + const int64_t hp_threads_per_block_max, + const bool project_to_input, + const PrimDataType index_type) { + rparams->tma_warp_specialized = true; + rparams->project_persistent_buffers = project_to_input; + rparams->cparams.index_type = index_type; + const auto dev_prop = at::cuda::getCurrentDeviceProperties(); + const int64_t device_multiprocessor_count = + (int64_t)dev_prop->multiProcessorCount; + // Parameters for inner reduction: + // Reduction dim: inner_vect, inner_batch, bdimx and bdimy + // Iteration dim: gdimy + + // Parameters for outer reduction: + // Reduction dim: bdimy + // Iteration dim: vectorization_factor_outer, bdimx, gdimy + struct InnerOuterParams { + int64_t inner_vect = -1; + int64_t inner_batch = -1; + int64_t bdimx = -1; + int64_t bdimy = -1; + int64_t bdimz = -1; + int64_t gdimy = -1; + int64_t tmp_gmem_write_vect = -1; + int64_t vectorization_factor_outer = -1; + int64_t threads_per_block = -1; + // derived metrics for sorting + int64_t warps_per_sm = -1; + int64_t required_register_per_thread = -1; + int64_t available_register_per_thread = -1; + + void verify() { + NVF_ERROR(inner_vect != -1, "inner_vect is not set."); + NVF_ERROR(inner_batch != -1, "inner_batch is not set."); + NVF_ERROR(bdimx != -1, "bdimx is not set."); + NVF_ERROR(bdimy != -1, "bdimy is not set."); + NVF_ERROR(gdimy != -1, "gdimy is not set."); + NVF_ERROR(tmp_gmem_write_vect != -1, "tmp_gmem_write_vect is not set."); + NVF_ERROR( + vectorization_factor_outer != -1, + "vectorization_factor_outer is not set."); + } + std::string toString() const { + std::stringstream ss; + ss << "inner_vect: " << inner_vect << ", inner_batch: " << inner_batch + << ", bdimx: " << bdimx << ", bdimy: " << bdimy << ", bdimz: " << bdimz + << ", gdimy: " << gdimy + << ", tmp_gmem_write_vect: " << tmp_gmem_write_vect + << ", vectorization_factor_outer: " << vectorization_factor_outer + << ", threads_per_block: " << threads_per_block + << ", warps_per_sm: " << warps_per_sm + << ", required_register_per_thread: " << required_register_per_thread + << ", available_register_per_thread: " + << available_register_per_thread; + return ss.str(); + } + }; + + // Set a minimum workload for each thread to take advantage of low + // intra-threads communication cost. + // Tuned for layer_norm backward on A100, still works fine on H100. + auto get_minimum_batch = [&]() -> int64_t { + if (inner_dim_numel >= 3072l) { + if (outer_dim_numel <= 2048l && inner_dim_numel == 3072l) { + return 3l; + } else { + return 4l; + } + } else if (inner_dim_numel >= 2048l) { + return 2l; + } + return 1l; + }; + + // Estimate register usage per thread based on buffer size. + // Assuming a constant register overhead for non-buffer related usage, + // and all the register buffers are stored in registers. + auto get_estimated_register_usage = [&](int64_t batch_mul_vect) { + int64_t persistent_buffer_size = + regs_buffer_size / inner_dim_numel * batch_mul_vect; + int64_t estimated_register_count = + persistent_buffer_size / scheduler_utils::bytes_per_register + + scheduler_utils::register_overhead; + return std::min( + estimated_register_count, scheduler_utils::max_registers_per_thread); + }; + + // The inner reduction part of the kernel also does a partial outer reduction + // and stores the partial results in tmp gmem and then reloaded to finish the + // outer reduciton. This function set the vectorization factor for write and + // and read of the partial outer reduction result. + // For write to tmp gmem, follows vectorization factor of inner reduction + // but don't exceed 16 bytes. + // For read from tmp gmem, since the paralelization is changed, a different + // vectorization factor is used to optimize the + // number of reaductions per thread. + auto get_outer_reduction_buffer_vect_factor = [&](int64_t inner_vect) { + constexpr int64_t max_gmem_vect_access_bytes = 16; + const int64_t max_tmp_gmem_vect_factor = std::min( + max_gmem_vect_access_bytes / (int64_t)tmp_gmem_dtype_size, inner_vect); + int64_t tmp_gmem_write_vect = max_tmp_gmem_vect_factor; + const int64_t workload_per_thread = inner_dim_numel >= 4096 ? 4l : 2l; + int64_t vectorization_factor_outer = + std::min(workload_per_thread, max_tmp_gmem_vect_factor); + return std::make_pair(tmp_gmem_write_vect, vectorization_factor_outer); + }; + + // In the outer reduction part of the kernel, inner and outer dims are + // parallelized as: + // --- inner dim: vect, bdimx, gdimy ---- + // --- outer dim: bdimy ----------------- + // This function splits the threads_per_block into bdimx and bdimy using: + // bdimx = ceilDiv(inner_dim_numel / vect, gdimy) + // bdimy = threads_per_block / bdimx + auto get_bdimx_bdimy = [&](int64_t threads_per_block, + int64_t vectorization_factor_outer, + int64_t gdimy) { + // For widely used hidden sizes, threads_per_block has factor of 8, roundup + // to increase the probability of bdimx * bdimy == threads_per_block. + int64_t bdimx = scheduler_utils::roundUpPow2Or8( + ceilDiv(inner_dim_numel / vectorization_factor_outer, gdimy)); + // if still not divisible, e.g. threads_per_block = 256, bdimx = 40. + // increase bdimx to make it divisible. Under worst case, bdimx equals to + // threads_per_block. + while (threads_per_block % bdimx) { + bdimx = std::min(bdimx + 8, threads_per_block); + } + // Set OuterParams Reduction dim: bdimy. + int64_t bdimy = threads_per_block / bdimx; + NVF_ERROR( + bdimy * bdimx == threads_per_block, + " threads_per_block must be divisible by bdimx and bdimy."); + return std::make_pair(bdimx, bdimy); + }; + + // Get the heuristics given vectorization factor and threads per block + auto get_heuristics_given_vect_threads = [&](int64_t vect_factor, + int64_t threads_per_block) { + InnerOuterParams iop; + // (1) inner reduction + // Reduction dim: inner_batch, threads_per_block, vect_factor + // Iteration dim: gdimy + iop.inner_vect = vect_factor; + iop.threads_per_block = threads_per_block; + iop.inner_batch = + ceilDiv(inner_dim_numel / iop.inner_vect, iop.threads_per_block); + iop.gdimy = device_multiprocessor_count; + + // (2) outer reduction + // Iteration dim: gdimy, bdimx, vectorization_factor_outer + // Reduction dim: bdimy + std::tie(iop.tmp_gmem_write_vect, iop.vectorization_factor_outer) = + get_outer_reduction_buffer_vect_factor(iop.inner_vect); + auto [bdimx, bdimy] = get_bdimx_bdimy( + threads_per_block, iop.vectorization_factor_outer, iop.gdimy); + iop.bdimx = bdimx; + iop.bdimy = bdimy; + // (3) Derived metrics warps_per_sm and register usage for sorting + iop.warps_per_sm = ceilDiv(iop.threads_per_block, dev_prop->warpSize) * + iop.gdimy / device_multiprocessor_count; + iop.available_register_per_thread = + getRegPerThreadGivenThreadsPerSM(dev_prop->warpSize * iop.warps_per_sm); + iop.required_register_per_thread = + get_estimated_register_usage(iop.inner_vect * iop.inner_batch); + return iop; + }; + + // Use the maximum vectorization factor + const int64_t vect_factor = (int64_t)vectorize_factor; + + // Set a reasonable range for threads per block based on the number of + // elements in the inner dimension after vectorization. + // Start from 128 or a smaller number if inner dim is small. + const int64_t after_vect = inner_dim_numel / vect_factor; + const int64_t batch_min = get_minimum_batch(); + int64_t threads_per_block_min = hp_threads_per_block_min; + threads_per_block_min = std::min(threads_per_block_min, after_vect); + threads_per_block_min = scheduler_utils::roundUpPow2(threads_per_block_min); + + // star max threads per block from min threads per block + int64_t threads_per_block_max = threads_per_block_min; + // increase to cover the whole inner dim + threads_per_block_max = + std::max(threads_per_block_max, ceilDiv(after_vect, batch_min)); + // round up to power of 2 + threads_per_block_max = scheduler_utils::roundUpPow2(threads_per_block_max); + // don't go beyond the maximum threads per block + threads_per_block_max = + std::min(threads_per_block_max, hp_threads_per_block_max); + + // Store all the possible heuristics based on different threads per block. + // Vectorizaton is fixed at the maximum value. + std::vector iop_candidates; + for (auto threads_per_block = threads_per_block_max; + threads_per_block >= threads_per_block_min; + threads_per_block /= 2) { + iop_candidates.emplace_back( + get_heuristics_given_vect_threads(vect_factor, threads_per_block)); + } + + // Sort the heuristics based on the register usage and occupancy. + std::stable_sort( + iop_candidates.begin(), + iop_candidates.end(), + [](const InnerOuterParams& a, const InnerOuterParams& b) { + // If a thread can use more registers than required, there is a high + // chance that it can avoid register spilling and compiler can optimize + // for better instruction level parallelism. + int64_t extra_regs_a = + a.available_register_per_thread - a.required_register_per_thread; + int64_t extra_regs_b = + b.available_register_per_thread - b.required_register_per_thread; + if (extra_regs_a > 0 && extra_regs_b < 0) { + return true; + } else if (extra_regs_a < 0 && extra_regs_b > 0) { + return false; + } + // High occupancy provides better threads level parallelism. + // 25% is sufficient since ILP is high due to persistent batch sizes + // which is equivalent to unrolling inner dim. + if (a.warps_per_sm != b.warps_per_sm && + (a.warps_per_sm < 16 || b.warps_per_sm < 16)) { + return a.warps_per_sm > b.warps_per_sm; + } + // Tie breaker, smaller threads_per_block to reduce communication + // overhead + return a.threads_per_block < b.threads_per_block; + }); + + // Pick the best heuristic + auto iop = iop_candidates.front(); + rparams->block_dim_inner_reduction_extra = ParallelType::TIDy; + rparams->combined_split_grid_inner_dim = + iop.vectorization_factor_outer * iop.bdimx * iop.gdimy < inner_dim_numel; + rparams->static_bdimx = true; + rparams->static_bdimy = true; + iop.bdimz = ceilDiv( + ceilDiv(ceilDiv(inner_dim_numel / iop.inner_vect, iop.bdimx), iop.bdimy), + iop.inner_batch); + NVF_ERROR(iop.bdimz == 1, "bdimz must be 1."); + + // check all the parameters in InnerOuterParams are set. + iop.verify(); + + rparams->persistent_kernel = true; + rparams->fastest_dim = true; + rparams->combined_inner_outer = true; + // tmp_gmem is the intermediate result of outer reduction, its dtype is float, + // so the maximum vectorization factor is 4. + rparams->vectorization_factor_outer = iop.vectorization_factor_outer; + rparams->vectorization_factor_tmp_gmem_write = iop.tmp_gmem_write_vect; + rparams->cparams.maxrregcount = iop.available_register_per_thread; + rparams->unroll_factor_inner_reduction = iop.inner_vect; + rparams->batches_per_block_inner_reduction = iop.inner_batch; + rparams->block_dim_inner_reduction = ParallelType::TIDx; + rparams->vectorize_inner_reduction = iop.inner_vect > 1; + rparams->split_grid_dim_iter_dom_outer = true; + rparams->grid_dim_iter_dom = ParallelType::BIDy; + + rparams->lparams = LaunchParams( + LaunchParams::UNINITIALIZED_VAL, + iop.gdimy, + LaunchParams::UNINITIALIZED_VAL, + iop.bdimx, + iop.bdimy, + LaunchParams::UNINITIALIZED_VAL); + + rparams->tag = "TMA Warp Specialized Persistent Heuristic.\n"; + + if (isDebugDumpEnabled(DebugDumpOption::SchedulerDebug)) { + debug() << "\n===== Combined InnerOuter Reduction Stats ========\n" + << "outer_dim_numel: " << outer_dim_numel << "\n" + << "inner_dim_numel: " << inner_dim_numel << "\n" + << "regs_buffer_size: " << regs_buffer_size << "\n" + << "smem_buffer_size: " << smem_buffer_size << "\n" + << "smem_overhead: " << smem_overhead << "\n" + << "vectorize_factor_input: " << iop.inner_vect << "\n" + << "vectorization_factor_tmp_gmem_write: " + << iop.tmp_gmem_write_vect << "\n" + << "vectorization_factor_outer: " << iop.vectorization_factor_outer + << "\n" + << "multiple_reds_per_blk: " << rparams->multiple_reds_per_blk + << "\n" + << "warps_per_sm: " << iop.warps_per_sm << "\n" + << "gdimy: " << iop.gdimy << "\n" + << "block(" << (iop.bdimx) << ", " << iop.bdimy << ", " << 1 << ")"; + debug() << rparams->toString() << std::endl; + } +} + +void scheduleOuterReduction( + Fusion* fusion, + const ReductionParams* rparams, + const std::vector& outer_reduction_tvs, + std::vector& cached_gmem, + std::vector& cached_gmem_reload, + std::vector& outer_reference_tvs, + std::unordered_set& boundaryNodesSet) { + auto mergeReductionOrIterDomains = [](TensorView* tv, bool mergeReduction) { + int prev_i = -1; + for (int i = static_cast(tv->nDims()) - 1; i >= 0; i--) { + if (mergeReduction == tv->axis(i)->isReduction()) { + if (prev_i == -1) { + prev_i = i; + } else { + tv->merge(i, prev_i); + prev_i = i; + } + } + } + }; + for (auto& outer_reduction_tv : outer_reduction_tvs) { + // Similar to the inner reduction, we need to reorder the outer reduction tv + // when there are view operations. + if (!ir_utils::getViewOps(fusion).empty()) { + // Reorder reference_tv after propagating the view operation. This will + // reorder for better merging. + outer_reduction_tv->reorder( + scheduler_utils::domainReorderAsLogicalMap(outer_reduction_tv)); + } + + // merge tensorview to [reduction, iteraiton] domains + mergeReductionOrIterDomains(outer_reduction_tv, true); + mergeReductionOrIterDomains(outer_reduction_tv, false); + + // First-stage of outer reduction + outer_reduction_tv->split(0, rparams->lparams.gdimy()); + + TensorView* partialResult = outer_reduction_tv->rFactor({0}); + partialResult->cacheBefore(); + partialResult->setMemoryType(MemoryType::Global); + TensorView* partialResultReload = partialResult->cacheAfter(); + + boundaryNodesSet.insert(partialResultReload); + cached_gmem.emplace_back(partialResult); + cached_gmem_reload.emplace_back(partialResultReload); + + // Second-stage of outer reduction + // reduction domain, [I1/TIDy, TIDy] + outer_reduction_tv->split(0, rparams->lparams.bdimy()); + outer_reduction_tv->axis(1)->parallelize(ParallelType::TIDy); + // iteration domain, [BIDy, TIDx, Vect] + int axisID = -1; + if (rparams->vectorization_factor_outer > 1) { + outer_reduction_tv->split(axisID, rparams->vectorization_factor_outer); + outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::Vectorize); + } + + if (rparams->lparams.bdimx() > 1) { + outer_reduction_tv->split(axisID, rparams->lparams.bdimx()); + outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::TIDx); + } + + if (rparams->combined_split_grid_inner_dim) { + outer_reduction_tv->split( + axisID, NamedScalar::getParallelDim(ParallelType::BIDy)); + } + + outer_reduction_tv->axis(axisID--)->parallelize(ParallelType::BIDy); + + auto outer_reference_tv = + reduction_scheduler_utils::sortAndRFactor(outer_reduction_tv); + outer_reference_tvs.emplace_back(outer_reference_tv); + } +} + +void scheduleFusion(Fusion* fusion, const ReductionParams* rparams) { + FusionGuard fg(fusion); + + // Grab the reduction, input, and output tensor views. dummy_outputs are + // helper tensors for persistent buffer projection. + std::vector dummy_outputs, cached_inputs, reduction_tvs, + smem_consumers; + std::vector> cached_outputs; + normalization_scheduler_utils::beforeSchedule( + fusion, + rparams, + dummy_outputs, + cached_inputs, + reduction_tvs, + smem_consumers, + cached_outputs); + + // split reduction_tvs into inner and outer reduction_tvs + std::vector inner_reduction_tvs, outer_reduction_tvs; + for (auto tv : reduction_tvs) { + if (scheduler_utils::isFastestDimReduction(tv)) { + inner_reduction_tvs.emplace_back(tv); + } else { + outer_reduction_tvs.emplace_back(tv); + } + } + NVF_ERROR( + !inner_reduction_tvs.empty(), + "schedulePersistentKernelInnerOuter is called but no inner reduction is found."); + NVF_ERROR( + !outer_reduction_tvs.empty(), + "schedulePersistentKernelInnerOuter is called but no outer reduction is found."); + + // schedule inner reduction, only schedule the first inner reduction tv, + // then will be propagated to other inner reduction tvs. + TensorView* inner_reference_tv = + normalization_scheduler_utils::scheduleReductionGeneral( + fusion, + rparams, + inner_reduction_tvs, + SchedulerType::InnerOuterPersistent); + + // schedule outer reduction, schedule all the outer reduction tvs since we + // need to store the intermediate results. + std::vector cached_gmem; + std::vector cached_gmem_reload; + std::vector outer_reference_tvs; + std::unordered_set boundaryNodesSet; + scheduleOuterReduction( + fusion, + rparams, + outer_reduction_tvs, + cached_gmem, + cached_gmem_reload, + outer_reference_tvs, + boundaryNodesSet); + + // Propagate inner reduction and outer reductions + for (auto output : dummy_outputs) { + fusion->addOutput(output); + } + + // Collect tvs loaded with TMA, they require special scheduling. + std::vector tma_load_tvs; + if (rparams->tma_warp_specialized) { + for (auto tv : smem_consumers) { + auto smem_tv = ir_utils::getSoleProducerTv(tv); + if (std::find(tma_load_tvs.begin(), tma_load_tvs.end(), smem_tv) == + tma_load_tvs.end()) { + tma_load_tvs.emplace_back(smem_tv); + } + } + } + + const bool is_unroll_or_vectorization = rparams->isUnrolled(); + const bool is_vectorize = + rparams->vectorize_inner_reduction || rparams->vectorize_iter_dom; + const bool is_outer_grid_persistence = rparams->persistent_kernel && + rparams->cross_grid_inner_reduction && !rparams->fastest_dim; + + // Propagate transformations for inner reduction. + // Two steps are used since tma tvs are scheduled differently. + // Step-1, propagate iteration domain in inner reduction. + // Step-2, propagate reduction domain in inner reduction. + if (rparams->tma_warp_specialized) { + // Find the axis that splits the reduction domain and iteration domain. + int first_redu_axis = -1; + int n_dims = (int)inner_reference_tv->nDims(); + for (auto i = 0; i < n_dims; i++) { + if (inner_reference_tv->axis(i)->isReduction() || + inner_reference_tv->axis(i)->isRFactorProduct()) { + first_redu_axis = i; + break; + } + } + + // Step-1, propagate iteration domain in inner reduction. + // outer_reference_tvs are excluded since they are already scheduled + // with a different pattern for the final step of outer reduciton. + if (first_redu_axis > 0) { + TransformPropagator propagator(inner_reference_tv, first_redu_axis - 1); + std::vector all_tvs_except = ir_utils::allTvsExcept( + fusion, {outer_reference_tvs.begin(), outer_reference_tvs.end()}); + SetSelector selector({all_tvs_except.begin(), all_tvs_except.end()}); + MaxLogicalDomainInfoSpanningTree(inner_reference_tv, &selector) + .traverse(&propagator); + } + + // Step-2, propagate reduction domain in inner reduction. + // (a) Tvs in boundaryNodesSet are excluded since they should follow outer + // reduction pattern. + // (b) TMA tvs are excluded since they require special scheduling. + // (3) Excluding tma tvs breaks the propagation path from inner reduction tv + // to cached_gmem which stores the results of the first-stage of outer + // reduction. The solution is adding a dummy output to link them. The same + // trick is used when projecting persistent buffers to inputs. + auto inner_reduction_input = + ir_utils::getSoleProducerTv(inner_reference_tv); + for (auto tv : cached_gmem) { + // T1(smem) --> T2 (l) --> T3 = OuterRedu(T2) --> T4(cached_gmem) + // outer_reduction_input: T2 + // partial_outer_redu_tv: T3 + auto partial_outer_redu_tv = ir_utils::getSoleProducerTv(tv); + auto outer_reduction_input = + ir_utils::getSoleProducerTv(partial_outer_redu_tv); + auto dummy_output = add(inner_reduction_input, outer_reduction_input); + fusion->addOutput(dummy_output); + dummy_outputs.emplace_back(dummy_output); + } + + // Tvs requiring special scheduling + std::unordered_set special_tvs{ + tma_load_tvs.begin(), tma_load_tvs.end()}; + for (auto tv : boundaryNodesSet) { + if (special_tvs.count(tv) == 0) { + special_tvs.emplace(tv); + } + } + TransformPropagator propagator(inner_reference_tv); + std::vector all_tvs_except_cache = ir_utils::allTvsExcept( + fusion, {special_tvs.begin(), special_tvs.end()}); + SetSelector selector( + {all_tvs_except_cache.begin(), all_tvs_except_cache.end()}); + MaxLogicalDomainInfoSpanningTree(inner_reference_tv, &selector) + .traverse(&propagator); + } else { + reduction_scheduler_utils::propagateTransformation( + inner_reference_tv, boundaryNodesSet); + } + reduction_scheduler_utils::propagateRFactor( + inner_reference_tv, inner_reduction_tvs[0], inner_reduction_tvs); + + // parallelization propagation + const auto& selected_tvs_inner = + scheduler_utils::getAllTvsFrom(inner_reduction_tvs, boundaryNodesSet); + const auto& unroll_vectorizable_cached_tvs = + reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize( + inner_reference_tv, is_vectorize, cached_inputs, cached_outputs); + reduction_scheduler_utils::propagateParallelization( + inner_reduction_tvs[0], + inner_reference_tv, + is_unroll_or_vectorization, + is_outer_grid_persistence, + inner_reduction_tvs, + unroll_vectorizable_cached_tvs, + {selected_tvs_inner.begin(), selected_tvs_inner.end()}); + + // Propagate outer reduction. Each outer reduction is connected with its + // cached_gmem and output, since we added all the cached_gmem to the + // boundaryNodesSet, the transformation from one outer reduction can't + // propagate to other outer reductions due to the cutoff at + // boundaryNodesSet. Thus, we need a loop to initiate the propagation from + // each outer reduction. Don't allow parallelization propagation goes + // through cached_gmem, see issue 246. + for (long unsigned int i = 0; i < outer_reference_tvs.size(); i++) { + const auto& selected_tvs_outer = scheduler_utils::getAllTvsFrom( + {outer_reduction_tvs[i]}, {cached_gmem[i]}); + reduction_scheduler_utils::propagateTransformation( + outer_reference_tvs[i], boundaryNodesSet); + const auto& unroll_vectorizable_cached_tvs = + reduction_scheduler_utils::getCachedTvsToUnrollOrVectorize( + outer_reference_tvs[i], + is_vectorize, + cached_inputs, + cached_outputs); + reduction_scheduler_utils::propagateParallelization( + outer_reduction_tvs[i], + outer_reference_tvs[i], + is_unroll_or_vectorization, + is_outer_grid_persistence, + outer_reduction_tvs, + unroll_vectorizable_cached_tvs, + {selected_tvs_outer.begin(), selected_tvs_outer.end()}); + } + + // Up to this point, the outer dimension of the TMA tv is scheduled + // the same way as the inner reduction tv. However, the inner dimension + // has not been scheduled yet. Since 1D TMA allows unrestricted load size, + // we can simply parallelize the entire inner dimension using bulk. + // Example: 2D tensor, [BIDy, S, | Bulk] + // Example: 1D tensor, [Bulk] + if (rparams->tma_warp_specialized) { + for (auto tv : tma_load_tvs) { + tv->axis(-1)->parallelize(ParallelType::Bulk); + } + } + + // special vectorization of temp gmem, vectorization_factor_tmp_gmem_write + // is guaranteed to be smaller or equal to input vectorization factor. + if (rparams->vectorization_factor_tmp_gmem_write > 1) { + for (auto tv : cached_gmem) { + NVF_ERROR( + rparams->vectorization_factor_tmp_gmem_write <= + rparams->unroll_factor_inner_reduction, + "vectorization factor of temp gmem write should be smaller than that of inner reduction.") + if (rparams->vectorization_factor_tmp_gmem_write < + rparams->unroll_factor_inner_reduction) { + tv->split(-1, rparams->vectorization_factor_tmp_gmem_write); + } + tv->axis(-1)->parallelize(ParallelType::Vectorize); + } + } + // vectorization propagate through propagateParallelization only works for + // input and output tensors. propagate vectorization to cached_gmem_reload + // directly from output tv using parallelizeAllLike. must propagate + // seperaely for different tvs as outer reductions are transformed + // seperately. + if (rparams->vectorization_factor_outer > 1) { + for (auto tv : cached_gmem_reload) { + auto output_tvs = ir_utils::outputTvsOf(tv); + NVF_ERROR( + !output_tvs.empty(), + "cached_gmem_reload should have at least one output tensor.") + scheduler_utils::parallelizeAllLike( + output_tvs[0], + -1, + {cached_gmem_reload.begin(), cached_gmem_reload.end()}, + {ParallelType::Vectorize}); + } + } + + // Needs special handling of vectorized loading from shared memory due to + // potential different data types of inputs and shared memory tensor. + if (is_vectorize) { + reduction_scheduler_utils::sharedMemoryConsumerVectorization( + smem_consumers, rparams->unroll_factor_inner_reduction); + } + + // Remove dummy outputs as they can inadvertently affect CA positions + for (auto output : dummy_outputs) { + fusion->removeOutput(output); + } + inlineMost(); +} +} // namespace inner_outer_tma_warp_specialized +} // namespace nvfuser diff --git a/csrc/scheduler/normalization_inner_outer_tma_ws.h b/csrc/scheduler/normalization_inner_outer_tma_ws.h new file mode 100644 index 00000000000..f3d05586508 --- /dev/null +++ b/csrc/scheduler/normalization_inner_outer_tma_ws.h @@ -0,0 +1,31 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include + +namespace nvfuser { +namespace inner_outer_tma_warp_specialized { +void getHeuristics( + ReductionParams* rparams, + const int64_t outer_dim_numel, + const int64_t inner_dim_numel, + const int64_t regs_buffer_size, + const int64_t smem_buffer_size, + const int64_t smem_overhead, + const size_t tmp_gmem_dtype_size, + const size_t vectorize_factor, + const int64_t hp_threads_per_block_min, + const int64_t hp_threads_per_block_max, + const bool project_to_input, + const PrimDataType index_type); + +void scheduleFusion(Fusion* fusion, const ReductionParams* rparams); +} // namespace inner_outer_tma_warp_specialized +} // namespace nvfuser diff --git a/csrc/scheduler/normalization_inner_outer_utils.cpp b/csrc/scheduler/normalization_inner_outer_utils.cpp new file mode 100644 index 00000000000..bcaaa3db131 --- /dev/null +++ b/csrc/scheduler/normalization_inner_outer_utils.cpp @@ -0,0 +1,301 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include +#include +#include +#include + +#include + +namespace nvfuser { +namespace inner_outer_utils { + +int64_t roundUpSharedMemory( + int64_t tv_buffer_size, + int64_t data_type_size, + int64_t vectorize_factor, + int64_t threads_per_block_min, + int64_t threads_per_block_max, + int64_t threads_per_block_step) { + int64_t dim_size = tv_buffer_size / data_type_size; + int64_t after_vect = dim_size / vectorize_factor; + int64_t max_smem = 0; + for (int64_t threads_per_block = threads_per_block_min; + threads_per_block <= threads_per_block_max; + threads_per_block += threads_per_block_step) { + int64_t n_batch = ceilDiv(after_vect, threads_per_block); + max_smem = std::max( + max_smem, + n_batch * vectorize_factor * threads_per_block * data_type_size); + } + return max_smem; +} + +std::vector getOuterBroadcastTvs( + Fusion* fusion, + const std::vector& reduction_tvs) { + // set reference broadcast mask using the first inner reduction tv + std::vector ref_broadcast_mask; + for (auto tv : reduction_tvs) { + if (scheduler_utils::isFastestDimReduction(tv)) { + const auto& logical = tv->getLogicalDomain(); + ref_broadcast_mask.reserve(logical.size()); + for (const auto i : arange(logical.size())) { + ref_broadcast_mask.push_back(!logical.at(i)->isReduction()); + } + break; + } + } + NVF_ERROR(!ref_broadcast_mask.empty(), "ref_broadcast_mask is empty!"); + + // find the broadcast tensor whose broadcast mask is same to the reference + std::vector outer_broadcast_tvs; + for (auto tv : fusion->allTvs()) { + if (std::any_of( + tv->getLoopDomain().begin(), + tv->getLoopDomain().end(), + [](IterDomain* id) { return id->isBroadcast(); })) { + if (auto bcast = dynamic_cast(tv->definition())) { + if (bcast->getBroadcastDimFlags() == ref_broadcast_mask) { + outer_broadcast_tvs.emplace_back(tv); + } + } + } + } + return outer_broadcast_tvs; +} + +int64_t partialOuterReductionBufferSize( + const std::vector& reduction_tvs, + SchedulerRuntimeInfo& runtime_info) { + int64_t partial_reduction_buffer_size = 0; + for (auto buffer : reduction_tvs) { + if (scheduler_utils::isFastestDimReduction(buffer)) { + continue; + } + int64_t buffer_size = -1; + for (auto id : buffer->getLogicalDomain()) { + if (id->isReduction() || id->isBroadcast()) { + continue; + } + auto id_size = runtime_info.expressionEvaluator().evaluate(id->extent()); + NVF_ERROR(id_size.hasValue(), "Could not infer persistent buffer size."); + if (buffer_size == -1) { + buffer_size = id_size.as(); + } else { + buffer_size *= id_size.as(); + } + } + buffer_size = (buffer_size == -1) ? 0 + : buffer_size * + (int64_t)dataTypeSize(buffer->getDataType().value(), + runtime_info.getIndexType()); + partial_reduction_buffer_size += buffer_size; + } + return partial_reduction_buffer_size; +} + +std::vector sortProjectableBufferInputs( + const std::vector& projectable_buffer_inputs, + const std::vector& outer_broadcast_tvs) { + // mark whether the buffer is used by outer broadcast tensors + std::unordered_map is_used_by_outer_bcast; + for (auto buffer : projectable_buffer_inputs) { + is_used_by_outer_bcast[buffer] = std::any_of( + outer_broadcast_tvs.begin(), + outer_broadcast_tvs.end(), + [&buffer](TensorView* tv) { + return DependencyCheck::isDependencyOf(buffer, tv); + }); + } + + // sort based on [is_used_by_outer_bcast] + std::vector sorted_buffer = projectable_buffer_inputs; + std::sort( + sorted_buffer.begin(), + sorted_buffer.end(), + [&](TensorView* a, TensorView* b) { + return !is_used_by_outer_bcast[a] && is_used_by_outer_bcast[b]; + }); + return sorted_buffer; +} + +PersistentBufferStorageParams getPersistentBufferStorageParams( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicDataCache* data_cache, + const std::vector& reduction_tvs, + const int64_t vectorize_factor, + const int64_t threads_per_block_min, + const int64_t threads_per_block_max) { + FUSER_PERF_SCOPE( + "normalization_inner_outer::getPersistentBufferStorageParams"); + + PersistentBufferStorageParams buffer_params; + + auto persistent_buffer_info_entry = + HeuristicDataCacheEntry( + data_cache, [&fusion]() { + return std::make_unique( + scheduler_utils::persistentBuffers(fusion)); + }); + + auto& persistent_buffer_info = persistent_buffer_info_entry.get(); + + auto persistent_buffer_size_info = scheduler_utils::persistentBufferSize( + fusion, runtime_info, persistent_buffer_info, data_cache); + + // Project to inputs when there is at least one outer broadcast tensor or + // projected persistent buffer size is smaller. When projecting to inputs, the + // outer broadcast tensor is reused in the loop over the iteration dimension, + // test shows it is faster than the non-projected version which requires + // reload from gmem for each iteration. + // Note: in current use cases (layer norm bwd and RMS norm bwd), there are + // outer broadcast tvs and always project to inputs. + // Warp specialized persistent kernel always cache inputs in shared memory, + // should project to inputs. + const auto& outer_broadcast_tvs = getOuterBroadcastTvs(fusion, reduction_tvs); + bool skip_check_buffer_size = !outer_broadcast_tvs.empty() || + isOptionEnabled(EnableOption::WarpSpecializedNormalization); + normalization_scheduler_utils::BufferProjectionStrategy project_strategy = + normalization_scheduler_utils::isProjectBufferToInputs( + fusion, + runtime_info, + reduction_tvs, + persistent_buffer_info, + persistent_buffer_size_info, + InnerOuterPersistentKernelScheduler::schedulerType(), + /*can_use_smem_persistent=*/true, + !skip_check_buffer_size); + + buffer_params.project_to_input = + (project_strategy == + normalization_scheduler_utils::BufferProjectionStrategy:: + ProjectToInputs); + + const auto dev_prop = at::cuda::getCurrentDeviceProperties(); + int64_t smem_overhead = scheduler_utils::getSharedMemoryOverheadPerBlock( + fusion, reduction_tvs, threads_per_block_max); + int64_t available_smem = + (int64_t)dev_prop->sharedMemPerMultiprocessor - smem_overhead; + int64_t available_regs = scheduler_utils::register_file_size_56k; + buffer_params.smem_overhead = smem_overhead; + + // (1) Use both register and shared memory. + // Start with all the cached input buffers in shared memory, they are loaded + // from global memory uses async copy which bypasses L1 cache. Outer reduction + // buffers are used to accumulate partial results of the outer reduction. They + // are not loaded from global memory and requires frequent read/write. So, + // they are always stored in registers. + // TODO: We may also move outer reduction buffers to shared + // memory to avoid segmentation when there are many outer reductions and + // hardware has larger shared memory, but these applications are rare, so this + // is not considered here. + auto buffers = buffer_params.project_to_input + ? persistent_buffer_info.projectable_buffer_inputs + : persistent_buffer_info.persistent_buffers; + + // Add buffers that are inputs to the fusion. They are not included in + // projectable_buffer_inputs since they are not projectable. + if (buffer_params.project_to_input) { + for (auto tv : persistent_buffer_info.persistent_buffers) { + if (tv->isFusionInput()) { + buffers.push_back(tv); + } + } + } + + // Needs to use rounded shared memory size to avoid over usage. + // key : buffer tv. + // val : register size and rounded shared memory size + std::unordered_map> + required_size_regs_smem_map; + int64_t total_smem_buffer_size = 0; + for (auto buffer : buffers) { + int64_t buffer_size_regs = scheduler_utils::getPersistentBufferSizeOfTensor( + buffer, runtime_info, persistent_buffer_info); + int64_t buffer_size_smem = roundUpSharedMemory( + buffer_size_regs, + dataTypeSize(buffer->getDataType().value()), + vectorize_factor, + threads_per_block_min, + threads_per_block_max, + dev_prop->warpSize); + required_size_regs_smem_map[buffer] = + std::make_pair(buffer_size_regs, buffer_size_smem); + total_smem_buffer_size += buffer_size_smem; + } + buffer_params.smem_buffer_size = total_smem_buffer_size; + buffer_params.regs_buffer_size = + partialOuterReductionBufferSize(reduction_tvs, runtime_info); + if (buffer_params.regs_buffer_size <= available_regs && + buffer_params.smem_buffer_size <= available_smem) { + buffer_params.smem_persistent_buffers = buffers; + buffer_params.has_enough_regs_and_smem = true; + return buffer_params; + } + + // Moving outer reduction buffer to shared memory is not considered yet, + // set to false if the outer reduction buffer size exceeds the register size. + if (buffer_params.regs_buffer_size > available_regs) { + buffer_params.has_enough_regs_and_smem = false; + return buffer_params; + } + + // (2) Now, shared memory is overused, move some buffers to registers. + // (2.1) Sort the candidate persistent buffers. No need to sort since the + // sorting is based on whether the buffer is used by outer broadcast tensors. + if (!outer_broadcast_tvs.empty()) { + buffers = sortProjectableBufferInputs(buffers, outer_broadcast_tvs); + } + // (2.2) Before this loop, all cached input buffers are in shared memory. Move + // buffer from shared memory to register. + int64_t n_regs_buffer = -1; + const int n_buffers = (int)buffers.size(); + for (int i = 0; i < n_buffers; i++) { + auto current_tv = buffers[i]; + auto [buffer_size_regs, buffer_size_smem] = + required_size_regs_smem_map.at(current_tv); + buffer_params.regs_buffer_size += buffer_size_regs; + buffer_params.smem_buffer_size -= buffer_size_smem; + + // The first-i buffers to are moved from shared memory to register + // If both the register buffer size and shared memory buffer size are within + // the allowable limit, we found a good configuration. + if (buffer_params.regs_buffer_size <= available_regs && + buffer_params.smem_buffer_size <= available_smem) { + n_regs_buffer = i + 1; + break; + } + // Register buffer size exceeds the limit, can't move more to registers. + // Break the loop. + if (buffer_params.regs_buffer_size > available_regs) { + break; + } + } + + // n_regs_buffer > 0 indicats a good configuration is found. + // The first n_regs_buffer buffers are stored in registers and last [n_buffers + // - n_regs_buffer] are stored in shared memory. + if (n_regs_buffer > 0) { + buffer_params.has_enough_regs_and_smem = true; + auto n_smem_buffer = n_buffers - n_regs_buffer; + buffer_params.smem_persistent_buffers.reserve(n_smem_buffer); + for (int i = 0; i < n_smem_buffer; i++) { + buffer_params.smem_persistent_buffers.emplace_back( + buffers[n_buffers - 1 - i]); + } + } else { + buffer_params.has_enough_regs_and_smem = false; + } + return buffer_params; +} + +} // namespace inner_outer_utils +} // namespace nvfuser diff --git a/csrc/scheduler/normalization_inner_outer_utils.h b/csrc/scheduler/normalization_inner_outer_utils.h new file mode 100644 index 00000000000..49b0699a00c --- /dev/null +++ b/csrc/scheduler/normalization_inner_outer_utils.h @@ -0,0 +1,98 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include + +namespace nvfuser { +class SchedulerRuntimeInfo; +class HeuristicDataCache; + +namespace inner_outer_utils { +// The roundup is due to the fact that the shared memory buffer is allocated +// as: ceilDiv(dim_size / vectorize_factor, threads_per_block). +// Let after_vect = dim_size / vectorize_factor; +// n_batch = ceilDiv(after_vect, threads_per_block); +// Then the shared memory buffer size is n_batch * vectorize_factor * +// threads_per_block * data_type_size. This function returns the maximum +// possible shared memory buffer size considering all possible block sizes. +int64_t roundUpSharedMemory( + int64_t tv_buffer_size, + int64_t data_type_size, + int64_t vectorize_factor, + int64_t threads_per_block_min, + int64_t threads_per_block_max, + int64_t threads_per_block_step); + +// Return the broadcast tvs that are broadcast to the iteration dimensions of +// the inner reduction tv. These tvs are reused in the loop over the iteration +// dimension. This reuse reduced the number loads from gmem and this tensor +// is likely the first candidate to be moved to shared memory when the register +// space runs low. +std::vector getOuterBroadcastTvs( + Fusion* fusion, + const std::vector& reduction_tvs); + +// Size of buffers storing intermediate outer reduction results +// TODO: check if we can directly start with [buffer_size = 1] +int64_t partialOuterReductionBufferSize( + const std::vector& reduction_tvs, + SchedulerRuntimeInfo& runtime_info); + +// Decide where to store persistent buffers. +// By default, they reside in registers. +// If register space runs low but there's ample shared memory, +// move one or more buffers to shared memory until the register space is +// sufficient. +struct PersistentBufferStorageParams { + // representing buffers that are stored in shared memory, other buffers are + // stored in registers. + std::vector smem_persistent_buffers; + + // Total number of bytes occupied by all persistent buffers stored in shared + // memory. + int64_t smem_buffer_size = -1; + + // Total number of bytes occupied by all persistent buffers stored in + // registers. + int64_t regs_buffer_size = -1; + + // Additional shared memory usage per block that is not associated with + // persistent buffers. This includes memory for driver overhead and workspace + // for reductions. + int64_t smem_overhead = -1; + + // Flag indicating whether there are sufficient registers and shared memory + // available to accommodate all persistent buffers as required for efficient + // execution. + bool has_enough_regs_and_smem = false; + + // Flag indicating whether the persistent buffers are recomputed using inputs. + bool project_to_input = false; +}; +PersistentBufferStorageParams getPersistentBufferStorageParams( + Fusion* fusion, + SchedulerRuntimeInfo& runtime_info, + HeuristicDataCache* data_cache, + const std::vector& reduction_tvs, + const int64_t vectorize_factor, + const int64_t threads_per_block_min, + const int64_t threads_per_block_max); + +// Prioritize keeping buffers used by outer broadcast tensors to shared memory +// because: +// (1) They are reused in every iteration of the outer loop, has lower IO. +// (2) Load occurs before the outer loop. Temporary register usage won't +// increase register pressure since the loop is the high-pressure region. +std::vector sortProjectableBufferInputs( + const std::vector& projectable_buffer_inputs, + const std::vector& outer_broadcast_tvs); + +} // namespace inner_outer_utils +} // namespace nvfuser diff --git a/csrc/scheduler/normalization_utils.cpp b/csrc/scheduler/normalization_utils.cpp index a92418998b8..3d20f9a5a5b 100644 --- a/csrc/scheduler/normalization_utils.cpp +++ b/csrc/scheduler/normalization_utils.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include #include #include @@ -1163,6 +1164,19 @@ bool compileTimeCheck(Fusion* fusion, SchedulerType scheduler_type) { scheduler_type, "no reduction tv"); return false; } + + // Reject when output IDs are not covered by reference tv. Assuming reduction + // scheduler simply uses reduction_tvs[0] as the reference, if that changes, + // this needs to be changed. see issue + // https://github.com/NVIDIA/Fuser/issues/3811 + scheduler_tools::DomainMap domain_map(fusion); + if (!domain_map.isValidReference(reduction_tvs[0], /*check_inputs=*/false)) { + scheduler_debug_utils::canScheduleRejectReason( + scheduler_type, + "Output contains ID that's not scheduled by reference tv."); + return false; + } + auto reduction_type = reduction_scheduler_utils::getReductionType(reduction_tvs); const SchedulerType persistent_heuristic = @@ -1319,11 +1333,15 @@ std::vector movePersistentBufferToSmem( } if (use_smem) { tv->setMemoryType(MemoryType::Shared); - // When loading from global memory (gmem), use CpAsync with a short data - // path of gmem -> smem to reduce temporary register usage. Otherwise, the - // data path from gmem to shared memory (smem) follows this sequence: gmem - // -> L1 cache -> register -> smem. - if (supportCpAsync(tv) && is_cached_input) { + // Use 1D TMA, CpAsyncBulk + if (rparams->tma_warp_specialized && is_cached_input) { + tv->definition()->as()->setOpType( + LoadStoreOpType::CpAsyncBulk); + } else if (supportCpAsync(tv) && is_cached_input) { + // When loading from global memory (gmem), use CpAsync with a short data + // path of gmem -> smem to reduce temporary register usage. Otherwise, + // the data path from gmem to shared memory (smem) follows this + // sequence: gmem -> L1 cache -> register -> smem. tv->definition()->as()->setOpType( LoadStoreOpType::CpAsync); tv->definition()->as()->setCacheOp(CacheOp::Unspecified); diff --git a/csrc/scheduler/reduction.cpp b/csrc/scheduler/reduction.cpp index bf5a7b2e38c..2ca11c97346 100644 --- a/csrc/scheduler/reduction.cpp +++ b/csrc/scheduler/reduction.cpp @@ -15,6 +15,7 @@ #include #include #include +#include #include #include #include @@ -1670,6 +1671,18 @@ bool ReductionScheduler::canScheduleCompileTime(Fusion* fusion) { return false; } + // Reject when output IDs are not covered by reference tv. Assuming reduction + // scheduler simply uses reduction_tvs[0] as the reference, if that changes, + // this needs to be changed. see issue + // https://github.com/NVIDIA/Fuser/issues/3811 + scheduler_tools::DomainMap domain_map(fusion); + if (!domain_map.isValidReference(reduction_tvs[0], /*check_inputs=*/false)) { + scheduler_debug_utils::canScheduleRejectReason( + schedulerType(), + "Output contains ID that's not scheduled by reference tv."); + return false; + } + if (registry_utils::hasNonUniqueBcast(fusion)) { scheduler_debug_utils::canScheduleRejectReason( schedulerType(), diff --git a/csrc/scheduler/reduction_utils.cpp b/csrc/scheduler/reduction_utils.cpp index a8c7652d196..96c7f8650b3 100644 --- a/csrc/scheduler/reduction_utils.cpp +++ b/csrc/scheduler/reduction_utils.cpp @@ -303,8 +303,10 @@ TensorView* scheduleReductionTV( } } } - - auto reduction_rf_tv = sortAndRFactor(reduction_tv); + const bool is_non_persistent_outer_reduction = + !rparams->persistent_kernel && !rparams->fastest_dim; + auto reduction_rf_tv = + sortAndRFactor(reduction_tv, is_non_persistent_outer_reduction); // In the case of outer grid persistence, make sure the vectorized // domain placed at the innermost position. @@ -647,7 +649,9 @@ bool placedBefore(const IterDomain* id0, const IterDomain* id1) { } } // namespace -TensorView* sortAndRFactor(TensorView* reference_tv) { +TensorView* sortAndRFactor( + TensorView* reference_tv, + bool is_non_persistent_outer_reduction) { auto domain = reference_tv->getLoopDomain(); std::sort(domain.begin(), domain.end(), placedBefore); std::unordered_map reorder_map; @@ -659,6 +663,41 @@ TensorView* sortAndRFactor(TensorView* reference_tv) { reorder_map[old_i] = domain_pos.at(reference_tv->axis(old_i)); } reference_tv->reorder(reorder_map); + // For outer reduction, if an Id after vectorization Id is a constant + // serial Id, swap it with the vectorization Id to reduce register usage. + // For example, in a thread-local outer reduction, we want to transform: + // [..., iV{8}, rS{7}, rUS{1}, rUR{4}] + // to: + // [..., rS{7}, iV{8}, rUS{1}, rUR{4}] + // After change, each thread only needs to cache 8 × 4 elements instead of + // 8 × 7 × 4 elements. + // See https://github.com/NVIDIA/Fuser/issues/4172 for real examples. + if (is_non_persistent_outer_reduction) { + auto vect_iter = + std::find_if(domain.begin(), domain.end(), [](IterDomain* id) { + return id->getParallelType() == ParallelType::Vectorize; + }); + if (vect_iter != domain.end()) { + int64_t vect_id_pos = vect_iter - domain.begin(); + std::unordered_map reorder_map; + for (auto iter = vect_iter + 1; iter != domain.end(); iter++) { + if ((*iter)->getParallelType() == ParallelType::Serial && + (*iter)->extent()->isConstScalar()) { + int64_t id_pos = iter - domain.begin(); + reorder_map[id_pos] = vect_id_pos++; + } + } + // Although we support reordering multiple constant serial IDs after the + // vectorization ID, the current scheduler only emits one. It may be worth + // exploring performance implications if multiple such IDs are introduced + // in the future. + NVF_ERROR( + reorder_map.size() <= 1, + "Expect one constant serial Id after vectorization Id, but found ", + reorder_map.size()); + reference_tv->reorder(reorder_map); + } + } std::vector rfactor_axes; std::vector rfactor_axes_no_unswitch; diff --git a/csrc/scheduler/reduction_utils.h b/csrc/scheduler/reduction_utils.h index 9985ad83ee1..abb8253f636 100644 --- a/csrc/scheduler/reduction_utils.h +++ b/csrc/scheduler/reduction_utils.h @@ -101,7 +101,9 @@ NVF_API void propagateParallelization( // Rfactored axes are reductions bound to grid or blocks. If no axes are bound // to a grid or block dimension it will rfactor the r-unswitch dimension. // Reduction inliner expects an rfactored domain. -NVF_API TensorView* sortAndRFactor(TensorView* reference_tv); +NVF_API TensorView* sortAndRFactor( + TensorView* reference_tv, + const bool is_non_persistent_outer_reduction = false); // If project_to_inputs is true, take all projectable persistent buffers, // and move them to the inputs. Otherwise, try to project to their immediate diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp index 33316c9480e..6e3261008e6 100644 --- a/csrc/scheduler/registry.cpp +++ b/csrc/scheduler/registry.cpp @@ -35,7 +35,12 @@ bool checkCanSchedule(Fusion* fusion, SchedulerType scheduler_type) { // These ops are are only accepted in `ExprEval` // scheduler, all other schedulers should reject them. - if (ir_utils::hasOpsOfType(fusion)) { + // TODO: remove IndexPutAccumulateOp + if (ir_utils::hasOpsOfType< + SdpaFwdOp, + SdpaBwdOp, + EmbeddingFwdOp, + IndexPutAccumulateOp>(fusion)) { scheduler_debug_utils::canScheduleRejectReason( scheduler_type, "Has unsupported ops"); return false; diff --git a/csrc/scheduler/tools/domain_map.cpp b/csrc/scheduler/tools/domain_map.cpp index 87dcbd358ee..f129eb6b78a 100644 --- a/csrc/scheduler/tools/domain_map.cpp +++ b/csrc/scheduler/tools/domain_map.cpp @@ -376,15 +376,18 @@ IterDomain* DomainMap::anyMapped( // Determine if output TensorView is a valid reference tensor for this fusion. // The reference tensor must map to all the iterDomains in each input and // output -bool DomainMap::isValidReference(TensorView* tv) const { - for (auto input_tv : ir_utils::filterByType(fusion_->inputs())) { - if (input_tv->uses().empty()) { - continue; - } - // TODO: Same backward traversal from tv is done for all input - // tvs. Consider doing the analysis one for all inputs - if (!areAllInputIdsMappedTo(input_tv, tv)) { - return false; +bool DomainMap::isValidReference(TensorView* tv, bool check_inputs) const { + if (check_inputs) { + for (auto input_tv : + ir_utils::filterByType(fusion_->inputs())) { + if (input_tv->uses().empty()) { + continue; + } + // TODO: Same backward traversal from tv is done for all input + // tvs. Consider doing the analysis one for all inputs + if (!areAllInputIdsMappedTo(input_tv, tv)) { + return false; + } } } // The check on outputs are optional, transpose scheduler might propose a diff --git a/csrc/scheduler/tools/domain_map.h b/csrc/scheduler/tools/domain_map.h index 8a8ccb33e91..d6ed2a3a367 100644 --- a/csrc/scheduler/tools/domain_map.h +++ b/csrc/scheduler/tools/domain_map.h @@ -34,7 +34,7 @@ class DomainMap { // Determine if a TensorView is a valid reference tensor for this fusion. // The reference tensor must map to all the iterDomains in each input and // output. - bool isValidReference(TensorView* tv) const; + bool isValidReference(TensorView* tv, bool check_inputs = true) const; protected: // Determine if all IterDomains are mapped between input and the given tvs diff --git a/csrc/scheduler/vectorize_helper.cpp b/csrc/scheduler/vectorize_helper.cpp index 6ec81585869..091f3519a0d 100644 --- a/csrc/scheduler/vectorize_helper.cpp +++ b/csrc/scheduler/vectorize_helper.cpp @@ -842,6 +842,59 @@ std::vector> getTvToContigInnerSizeMapsOf( return mappers; } +// Check if a traversal from vectorized reference IDs may reach the +// IDs of a resize expr without visiting the Resize expr itself. That's +// problematic for the vectorization analysis as the spanning-tree +// based analysis may miss the constraint by the Resize expr. +// +// For this analysis, we start a traversal from the vectorized +// reference IDs to both the input and output of the Resize expr but +// disallow visiting the Resize expr itself. If the traversal is still +// successful, it means there's a path from the reference IDs to the +// resize input and output IDs without visiting the Resize expr. +// +// Permissive BFS is used in this traversal as the vectorized +// reference IDs may not have all the dependencies for the +// traversal. For example, suppose there's a split resshape, and only +// the innermost ID is vectorized. The standard BFS is not able to +// move forward if only the vectorized ID is give as the backward +// split requires both outputs to be presented. +class CanSkipResize : public ValGraphPermissiveBFS { + public: + static bool run( + const ValGraph& graph, + const ValGroups& ref_groups, + Resize* resize) { + ValGroups resize_in_out_groups; + resize_in_out_groups.pushBack(graph.toGroup(resize->in())); + resize_in_out_groups.pushBack(graph.toGroup(resize->out())); + CanSkipResize bfs(graph, ref_groups, resize_in_out_groups, resize); + bfs.traverse(); + return bfs.allToNodesVisited(); + } + + CanSkipResize( + const ValGraph& graph, + const ValGroups& ref_groups, + const ValGroups& resize_in_out_groups, + Resize* resize) + : ValGraphPermissiveBFS( + graph, + {ref_groups.begin(), ref_groups.end()}, + {resize_in_out_groups.begin(), resize_in_out_groups.end()}, + /*require_all_to_visited=*/false, + /*allowed_direction=*/Direction::Undefined), + resize_(resize) {} + + bool excludeFromTraversal(const NodeType& node) const override { + const ExprGroup* e = std::get_if(&node); + return e != nullptr && (*e)->has(resize_); + } + + private: + Resize* resize_ = nullptr; +}; + // This is a WAR for vectorizing through resized iter domains. The // spanning tree based analysis is not guaranteed to take all resize // ops into considerations (issue @@ -852,84 +905,48 @@ std::unordered_set getResizeVectorizationFactors( TensorView* reference_tv, int64_t break_point) { Fusion* fusion = reference_tv->fusion(); - std::unordered_set factors; const auto resize_based_ops = scheduler_tools::getResizeBasedOps(fusion); if (resize_based_ops.empty()) { - return factors; + return {}; } - IdModel id_model(reference_tv->fusion()); + IdModel id_model(fusion); const auto& graph = id_model.buildExactGraph(); - const auto ref_groups = graph.toGroups(reference_tv->getLogicalDomain()); + std::unordered_set resize_factors; - // For each of resize-based tensor ops, find all resize ops - // that exist between the vectorized reference IDs and the output - // tensor. - for (auto resize_based_op : resize_based_ops) { - auto resize_out = resize_based_op->output(0)->as(); - NVF_ERROR( - resize_out->hasRoot(), "Unexpected op: ", resize_based_op->toString()); - // getAllExprGroupsBetween finds exprs between IDs. To make sure - // the the resize op of this resize_based_op tensor op is found, - // use both the root and logical domains as the traversal targets. - ValGroups resize_inp_out; - resize_inp_out.pushBack(graph.toGroups(resize_out->getRootDomain())); - resize_inp_out.pushBack(graph.toGroups(resize_out->getLogicalDomain())); - - auto expr_path = getAllExprGroupsBetween( - graph, - ref_groups, - resize_inp_out, - /*require_all_to_visited=*/false) - .first; - - ValGroups vectorized_groups; - for (auto it = reference_tv->getLogicalDomain().begin() + break_point; - it != reference_tv->getLogicalDomain().end(); - ++it) { - vectorized_groups.pushBack(graph.toGroup(*it)); + auto add_resize_factors = [&](Resize* resize) { + if (!resize->leftExpand()->isZeroInt()) { + resize_factors.insert(resize->leftExpand()); } + if (!resize->rightExpand()->isZeroInt()) { + resize_factors.insert(resize->rightExpand()); + } + }; - // Find all resize exprs that appear in expr_path and depend on - // vectorized_groups. Since expr_path is not guaranteed to be - // topologically sorted, need to loop through the path until - // converged. - - bool something_has_changed = true; - while (something_has_changed) { - something_has_changed = false; - for (const auto& [expr_g, dir] : expr_path) { - const auto inputs = getInputsOfExprGroup(graph, expr_g, dir); - if (std::none_of( - inputs.begin(), inputs.end(), [&](const ValGroup& inp) { - return vectorized_groups.has(inp); - })) { - continue; - } - - if (vectorized_groups.pushBack( - getOutputsOfExprGroup(graph, expr_g, dir))) { - something_has_changed = true; - } - - auto resize = dynamic_cast(expr_g->front()); - if (resize == nullptr) { - continue; - } + const ValGroups ref_vec_groups = graph.toGroups(std::vector{ + reference_tv->getLogicalDomain().begin() + break_point, + reference_tv->getLogicalDomain().end()}); + + // For each of Resize exprs, if it's reachable from the reference + // vectorized IDs without visiting the Resize expr itself, its + // constraint may not be reflectd in the inner sizes. + for (auto resize : resize_based_ops) { + auto resize_out_tv = resize->output(0)->as(); + for (const auto logical_id : resize_out_tv->getLogicalDomain()) { + auto resize = dynamic_cast(logical_id->definition()); + if (resize == nullptr) { + continue; + } - // These three vals need to be divisible - factors.emplace(resize->leftExpand()); - factors.emplace(resize->rightExpand()); - factors.emplace( - dir == Direction::Forward ? resize->out()->extent() - : resize->in()->extent()); + if (CanSkipResize::run(graph, ref_vec_groups, resize)) { + add_resize_factors(resize); } } } - return factors; + return resize_factors; } } // namespace @@ -1028,7 +1045,11 @@ int64_t getVectorizationFactor( if (!inferred_val.hasValue()) { return 1; } - max_vec_size = std::gcd(max_vec_size, inferred_val.as()); + auto inferred_val_int = inferred_val.as(); + if (inferred_val_int == 0) { + continue; + } + max_vec_size = std::gcd(max_vec_size, inferred_val_int); } return max_vec_size; diff --git a/csrc/serde/fusion_cache.fbs b/csrc/serde/fusion_cache.fbs index c2b90b08a5c..def7a760eea 100644 --- a/csrc/serde/fusion_cache.fbs +++ b/csrc/serde/fusion_cache.fbs @@ -48,6 +48,7 @@ enum RecordType: int { FullOp, IotaOp, IndexSelectOp, + IndexPutAccumulateOp, SelectOp, GatherOp, TakeAlongAxisOp, diff --git a/csrc/serde/fusion_record.cpp b/csrc/serde/fusion_record.cpp index ce9412770d9..0233174d260 100644 --- a/csrc/serde/fusion_record.cpp +++ b/csrc/serde/fusion_record.cpp @@ -472,6 +472,13 @@ void RecordFunctorFactory::registerAllParsers() { }; registerParser(RecordType::IndexSelectOp, deserializeIndexSelectRecord); + auto deserializeIndexPutAccumulateRecord = [](const RecordFunctor* buffer) { + return new python_frontend::IndexPutAccumulateOpRecord( + parseStateArgs(buffer->args()), parseStateArgs(buffer->outputs())); + }; + registerParser( + RecordType::IndexPutAccumulateOp, deserializeIndexPutAccumulateRecord); + auto deserializeSelectRecord = [](const RecordFunctor* buffer) { return new python_frontend::SelectOpRecord( parseStateArgs(buffer->args()), diff --git a/csrc/type.cpp b/csrc/type.cpp index d1a5b2abd80..31163c7ced7 100644 --- a/csrc/type.cpp +++ b/csrc/type.cpp @@ -729,7 +729,7 @@ static const char* parallel_type2string(ParallelType t) { case ParallelType::TIDx: return "threadIdx.x"; case ParallelType::Stream: - return "Stream"; + return "StreamIdx"; case ParallelType::Vectorize: return "V"; case ParallelType::Unroll: @@ -1629,8 +1629,8 @@ std::ostream& operator<<( case CircularBufferLoopStage::Epilog: os << "{CircularBufferEpilog}"; break; - case CircularBufferLoopStage::LoadWarp: - os << "{LoadWarp}"; + case CircularBufferLoopStage::AsyncWarp: + os << "{AsyncWarp}"; break; case CircularBufferLoopStage::ComputeWarp: os << "{ComputeWarp}"; diff --git a/csrc/type.h b/csrc/type.h index 6a032aa5495..cb8561b25b8 100644 --- a/csrc/type.h +++ b/csrc/type.h @@ -819,7 +819,7 @@ enum class CircularBufferLoopStage { Prolog = 0, Main, Epilog, - LoadWarp, + AsyncWarp, ComputeWarp, EndOfStages, // A special placeholder used to iterate over all stages NotApplicable @@ -831,7 +831,7 @@ enum class CircularBufferLoopStage { inline bool hasCircularBufferLoad(CircularBufferLoopStage stage) { return stage == CircularBufferLoopStage::Prolog || stage == CircularBufferLoopStage::Main || - stage == CircularBufferLoopStage::LoadWarp; + stage == CircularBufferLoopStage::AsyncWarp; } // The consuming expressions of circular buffer are cloned for these circular @@ -851,7 +851,7 @@ inline bool hasCircularBufferConsume(CircularBufferLoopStage stage) { // somewhere (*may or may not be in this loop*) inline bool mayHaveWarHazard(CircularBufferLoopStage stage) { return stage == CircularBufferLoopStage::Main || - stage == CircularBufferLoopStage::LoadWarp || + stage == CircularBufferLoopStage::AsyncWarp || stage == CircularBufferLoopStage::ComputeWarp; } diff --git a/nvfuser b/nvfuser new file mode 120000 index 00000000000..25e57deb181 --- /dev/null +++ b/nvfuser @@ -0,0 +1 @@ +python/nvfuser \ No newline at end of file diff --git a/python/LICENSE b/python/LICENSE new file mode 120000 index 00000000000..ea5b60640b0 --- /dev/null +++ b/python/LICENSE @@ -0,0 +1 @@ +../LICENSE \ No newline at end of file diff --git a/nvfuser/README.md b/python/nvfuser/README.md similarity index 100% rename from nvfuser/README.md rename to python/nvfuser/README.md diff --git a/nvfuser/__init__.py b/python/nvfuser/__init__.py similarity index 100% rename from nvfuser/__init__.py rename to python/nvfuser/__init__.py diff --git a/nvfuser/__init__.pyi b/python/nvfuser/__init__.pyi similarity index 100% rename from nvfuser/__init__.pyi rename to python/nvfuser/__init__.pyi diff --git a/python/nvfuser/benchmark_utils.py b/python/nvfuser/benchmark_utils.py new file mode 100644 index 00000000000..4949bbc599b --- /dev/null +++ b/python/nvfuser/benchmark_utils.py @@ -0,0 +1,105 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +from torch.autograd import DeviceType +from torch.profiler import profile, ProfilerActivity +import torch + + +# Base class for all timers used by pytest-benchmark. +class Timer: + def __init__(self): + self.current_time = 0.0 + + def _increment_global_time(self, elapsed_time: float) -> None: + self.current_time += elapsed_time + + def __call__(self): + raise NotImplementedError("Subclass must implement this method") + + def cleanup(self): + pass + + +class TorchProfileTimer(Timer): + def __init__(self): + super().__init__() + self.prof = profile(activities=[ProfilerActivity.CUDA]) + + def _get_kernel_time( + self, prof_averages: torch.autograd.profiler_util.EventList + ) -> float: + """ + Arguments: + prof_averages: Output of self.prof.key_averages() + Returns: + time_value: Elapsed CUDA time in seconds. + """ + elapsed_cuda_time = 0 + has_cuda_event = False + for event in prof_averages: + if event.device_type != DeviceType.CUDA: + continue + has_cuda_event = True + # Re: torch profiler API changes in https://github.com/pytorch/pytorch/pull/123247 + elapsed_cuda_time = ( + elapsed_cuda_time + event.self_device_time_total + if hasattr(event, "self_device_time_total") + else event.self_cuda_time_total + ) + assert has_cuda_event, "No CUDA events found" + return elapsed_cuda_time / 1e6 + + def __call__(self): + """ + Custom torchprofiler-based timer used by pytest-benchmark. + At every timer call, the profiler is stopped to compute the elapsed CUDA time + and the global clock is incremented. The profiler is restarted before returning to continue tracing. + + Returns: + self.current_time: Global monotonic clock variable + """ + try: + self.prof.stop() + except AssertionError: + self.prof.start() + return self.current_time + + prof_averages = self.prof.key_averages() + elapsed_cuda_time = self._get_kernel_time(prof_averages) + self._increment_global_time(elapsed_cuda_time) + # Clear the internal profiler object to avoid accumulating function events and then restart the profiler + # See PR: https://github.com/pytorch/pytorch/pull/125510 + self.prof.profiler = None + + return self.current_time + + def cleanup(self): + """ + Stops a running torchprofiler instance if found. + """ + try: + self.prof.stop() + except AssertionError: + pass + + +class FusionProfileTimer(Timer): + def __init__(self): + super().__init__() + self.fd = None + # Specifies if the timer in host measurement is called at the start/finish of execution. + # Timings are measured at the end of execution. + self.execution_start = True + + def set_fd(self, fd): + self.fd = fd + + def __call__(self): + if not self.execution_start: + profile = self.fd.profile() + elapsed_host_time = profile.host_time_ms / 1e3 + self._increment_global_time(elapsed_host_time) + self.execution_start = not self.execution_start + return self.current_time diff --git a/nvfuser/contrib/__init__.py b/python/nvfuser/contrib/__init__.py similarity index 100% rename from nvfuser/contrib/__init__.py rename to python/nvfuser/contrib/__init__.py diff --git a/nvfuser/contrib/nn/__init__.py b/python/nvfuser/contrib/nn/__init__.py similarity index 100% rename from nvfuser/contrib/nn/__init__.py rename to python/nvfuser/contrib/nn/__init__.py diff --git a/nvfuser/contrib/nn/normalization.py b/python/nvfuser/contrib/nn/normalization.py similarity index 100% rename from nvfuser/contrib/nn/normalization.py rename to python/nvfuser/contrib/nn/normalization.py diff --git a/nvfuser/nvfuser_version.py b/python/nvfuser/nvfuser_version.py similarity index 100% rename from nvfuser/nvfuser_version.py rename to python/nvfuser/nvfuser_version.py diff --git a/nvfuser/pytorch_utils.py b/python/nvfuser/pytorch_utils.py similarity index 100% rename from nvfuser/pytorch_utils.py rename to python/nvfuser/pytorch_utils.py diff --git a/nvfuser/testing/__init__.py b/python/nvfuser/testing/__init__.py similarity index 100% rename from nvfuser/testing/__init__.py rename to python/nvfuser/testing/__init__.py diff --git a/nvfuser/testing/utils.py b/python/nvfuser/testing/utils.py similarity index 100% rename from nvfuser/testing/utils.py rename to python/nvfuser/testing/utils.py diff --git a/nvfuser/utils.py b/python/nvfuser/utils.py similarity index 100% rename from nvfuser/utils.py rename to python/nvfuser/utils.py diff --git a/python/pyproject.toml b/python/pyproject.toml new file mode 100644 index 00000000000..d7813c1ed06 --- /dev/null +++ b/python/pyproject.toml @@ -0,0 +1,3 @@ +[build-system] +requires = ["setuptools>=42", "wheel", "ninja", "cmake>=3.18"] +build-backend = "setuptools.build_meta:__legacy__" diff --git a/csrc/python_frontend/distributed_tensor.cpp b/python/python_frontend/distributed_tensor.cpp similarity index 100% rename from csrc/python_frontend/distributed_tensor.cpp rename to python/python_frontend/distributed_tensor.cpp diff --git a/csrc/python_frontend/distributed_tensor.h b/python/python_frontend/distributed_tensor.h similarity index 100% rename from csrc/python_frontend/distributed_tensor.h rename to python/python_frontend/distributed_tensor.h diff --git a/csrc/python_frontend/fusion_cache.cpp b/python/python_frontend/fusion_cache.cpp similarity index 100% rename from csrc/python_frontend/fusion_cache.cpp rename to python/python_frontend/fusion_cache.cpp diff --git a/csrc/python_frontend/fusion_cache.h b/python/python_frontend/fusion_cache.h similarity index 100% rename from csrc/python_frontend/fusion_cache.h rename to python/python_frontend/fusion_cache.h diff --git a/csrc/python_frontend/fusion_definition.cpp b/python/python_frontend/fusion_definition.cpp similarity index 99% rename from csrc/python_frontend/fusion_definition.cpp rename to python/python_frontend/fusion_definition.cpp index c48abc9dbdc..b77947f1415 100644 --- a/csrc/python_frontend/fusion_definition.cpp +++ b/python/python_frontend/fusion_definition.cpp @@ -7,6 +7,7 @@ // clang-format on #include #include +#include #include #include #include @@ -452,6 +453,10 @@ std::pair> FusionDefinition:: if (scheds->multi_device_executor == nullptr) { MultiDeviceExecutorParams params; params.lower.communicator_backend = backend_type_; + // Disable StreamParallelType pass temporarily as proper stream lowering + // gets implemented + preseg_passes::OptimizationPassGuard guard( + false); scheds->multi_device_executor = std::make_unique( std::make_unique(*scheds->preschedFusion()), Communicator::getInstance(), diff --git a/csrc/python_frontend/fusion_definition.h b/python/python_frontend/fusion_definition.h similarity index 100% rename from csrc/python_frontend/fusion_definition.h rename to python/python_frontend/fusion_definition.h diff --git a/csrc/python_frontend/fusion_record.h b/python/python_frontend/fusion_record.h similarity index 99% rename from csrc/python_frontend/fusion_record.h rename to python/python_frontend/fusion_record.h index 3a6af8cfeb3..b437c5e247b 100644 --- a/csrc/python_frontend/fusion_record.h +++ b/python/python_frontend/fusion_record.h @@ -3095,6 +3095,30 @@ struct EmbeddingFwdOpRecord : RecordFunctor { } }; +struct IndexPutAccumulateOpRecord : RecordFunctor { + IndexPutAccumulateOpRecord( + std::vector args, + std::vector outputs) + : RecordFunctor( + std::move(args), + std::move(outputs), + "ops.index_put_accumulate", + serde::RecordType::IndexPutAccumulateOp) {} + ~IndexPutAccumulateOpRecord() override = default; + RecordFunctor* clone() final { + return new IndexPutAccumulateOpRecord(*this); + } + + void operator()(FusionState& fd) final { + auto acc = fd.getFusionState(args_.at(0).index)->as(); + auto index = fd.getFusionState(args_.at(1).index)->as(); + auto value = fd.getFusionState(args_.at(2).index)->as(); + + auto output = indexPutAccumulate(acc, index, value); + fd.setFusionState(outputs_.at(0).index, output); + } +}; + } // namespace nvfuser::python_frontend //! Creating the template specialized hash and equal_to functions for a diff --git a/csrc/python_frontend/fusion_state.cpp b/python/python_frontend/fusion_state.cpp similarity index 100% rename from csrc/python_frontend/fusion_state.cpp rename to python/python_frontend/fusion_state.cpp diff --git a/csrc/python_frontend/fusion_state.h b/python/python_frontend/fusion_state.h similarity index 100% rename from csrc/python_frontend/fusion_state.h rename to python/python_frontend/fusion_state.h diff --git a/csrc/python_frontend/multidevice_bindings.cpp b/python/python_frontend/multidevice_bindings.cpp similarity index 100% rename from csrc/python_frontend/multidevice_bindings.cpp rename to python/python_frontend/multidevice_bindings.cpp diff --git a/csrc/python_frontend/python_bindings.cpp b/python/python_frontend/python_bindings.cpp similarity index 99% rename from csrc/python_frontend/python_bindings.cpp rename to python/python_frontend/python_bindings.cpp index 123eec51263..d8b7ad291d4 100644 --- a/csrc/python_frontend/python_bindings.cpp +++ b/python/python_frontend/python_bindings.cpp @@ -3091,6 +3091,48 @@ void initNvFuserPythonBindings(PyObject* module) { py::arg("index"), py::arg("dim"), py::return_value_policy::reference); + nvf_ops.def( + "index_put_accumulate", + [](FusionDefinition::Operators& self, + Tensor acc, + Tensor index, + Tensor value) -> Tensor { + FUSER_PERF_SCOPE("Operators.index_put_accumulate"); + NVF_CHECK( + self.validUse(), "Attempting to add to a completed definition!"); + FusionDefinition* fd = self.fusion_definition; + Tensor output = fd->defineTensor(acc.dims); + fd->defineRecord(new IndexPutAccumulateOpRecord( + { + fd->recordingState(acc()), + fd->recordingState(index()), + fd->recordingState(value()), + }, + {fd->recordingState(output())})); + return output; + }, + py::arg("acc"), + py::arg("index"), + py::arg("value"), + py::return_value_policy::reference, + R"doc( + Accumulates values into a tensor at specified indices. + + This function performs a restricted version of `torch.index_put`. + It adds the values from `value_tv` to the elements of `acc_tv` at the indices + specified by `index_tv`. + + acc_tv: The tensor to accumulate into (in-place modification). + index_tv: The tensor containing the indices. + value_tv: The tensor containing the values to accumulate. + + Returns: + An alias to the modified `acc_tv` tensor. + + Note: + This is a restricted version and may not support all features of the + full `torch.index_put(..., accumulate=true)` function. + )doc"); nvf_ops.def( "select", [](FusionDefinition::Operators& self, diff --git a/csrc/python_frontend/python_bindings.h b/python/python_frontend/python_bindings.h similarity index 100% rename from csrc/python_frontend/python_bindings.h rename to python/python_frontend/python_bindings.h diff --git a/csrc/python_frontend/python_bindings_extension.cpp b/python/python_frontend/python_bindings_extension.cpp similarity index 100% rename from csrc/python_frontend/python_bindings_extension.cpp rename to python/python_frontend/python_bindings_extension.cpp diff --git a/csrc/python_frontend/schedule_bindings.cpp b/python/python_frontend/schedule_bindings.cpp similarity index 100% rename from csrc/python_frontend/schedule_bindings.cpp rename to python/python_frontend/schedule_bindings.cpp diff --git a/csrc/python_frontend/segmentation.cpp b/python/python_frontend/segmentation.cpp similarity index 100% rename from csrc/python_frontend/segmentation.cpp rename to python/python_frontend/segmentation.cpp diff --git a/csrc/python_frontend/segmentation.h b/python/python_frontend/segmentation.h similarity index 100% rename from csrc/python_frontend/segmentation.h rename to python/python_frontend/segmentation.h diff --git a/csrc/python_frontend/translation.cpp b/python/python_frontend/translation.cpp similarity index 100% rename from csrc/python_frontend/translation.cpp rename to python/python_frontend/translation.cpp diff --git a/csrc/python_frontend/translation.h b/python/python_frontend/translation.h similarity index 100% rename from csrc/python_frontend/translation.h rename to python/python_frontend/translation.h diff --git a/csrc/python_frontend/translation_utils.cpp b/python/python_frontend/translation_utils.cpp similarity index 100% rename from csrc/python_frontend/translation_utils.cpp rename to python/python_frontend/translation_utils.cpp diff --git a/csrc/python_frontend/translation_utils.h b/python/python_frontend/translation_utils.h similarity index 100% rename from csrc/python_frontend/translation_utils.h rename to python/python_frontend/translation_utils.h diff --git a/python/setup.py b/python/setup.py new file mode 100644 index 00000000000..4b39f2563fe --- /dev/null +++ b/python/setup.py @@ -0,0 +1,109 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Usage: +# pip install --no-build-isolation -e python -v +# This build command is equivalent to: python setup.py develop +# Options: +# -v: verbose output +# --no-build-isolation: don't build in a temporary directory +# -e: install in development mode +# +# Environment variables used during build: +# MAX_JOBS +# maximum number of compile jobs we should use to compile your code +# +# NVFUSER_BUILD_CMAKE_ONLY +# Only generate ./build directory with cmake setup +# +# NVFUSER_BUILD_NO_PYTHON +# Skips python API target `libnvfuser.so`, i.e. `_C.cpython-xxx.so` +# +# NVFUSER_BUILD_NO_TEST +# Skips cpp tests `test_nvfuser` +# +# NVFUSER_BUILD_NO_BENCHMARK +# Skips benchmark target `nvfuser_bench` +# +# NVFUSER_BUILD_NO_NINJA +# In case you want to use make instead of ninja for build +# +# NVFUSER_BUILD_WITH_UCC +# Build nvfuser with UCC support. You may need to specify environment variables of UCC_HOME, UCC_DIR, UCX_HOME, UCX_DIR. +# +# NVFUSER_BUILD_WITHOUT_DISTRIBUTED +# Build nvfuser without multidevice support +# +# NVFUSER_BUILD_TYPE=Debug +# Building nvfuser in debug mode +# +# NVFUSER_BUILD_TYPE=RelwithDebInfo +# Building nvfuser in release mode with debug info, a.k.a. RelwithDebInfo +# +# NVFUSER_BUILD_DIR= +# Specify in which directory to build nvfuser. If not specified, the default build directory is "./build". +# +# NVFUSER_BUILD_INSTALL_DIR= +# Specify in which directory to install nvfuser. If not specified, the default install directory is "./python/nvfuser". +# +# NVFUSER_BUILD_VERSION_TAG=TAG +# Specify the tag for build nvfuser version, this is used for pip wheel +# package nightly where we might want to add a date tag +# nvfuser-VERSION+TAG+gitSHA1-....-whl +# +# NVFUSER_BUILD_INSTALL_REQUIRES=pkg0[,pkg1...] +# this is used for pip wheel build to specify package required for install +# e.g. NVFUSER_BUILD_INSTALL_REQUIRES=nvidia-cuda-nvrtc-cu12 +# +# NVFUSER_BUILD_WHEEL_NAME=NAME +# Specify the wheel name this is used for pip wheel package where we want +# to identify the cuda toolkit version +# +# NVFUSER_BUILD_CPP_STANDARD=STANDARD +# Specify the C++ standard to use for building nvfuser. The default is C++20. +# + +import sys + +from utils import ( + run, + create_build_config, + override_build_config_from_env, +) + + +def version_tag(config): + from tools.gen_nvfuser_version import get_version + + version = get_version() + if config.overwrite_version: + version = version.split("+")[0] + if len(config.version_tag) != 0: + # use "." to be pypi friendly + version = ".".join([version, config.version_tag]) + return version + + +def main(): + # Parse arguments using argparse + # Use argparse to create description of arguments from command line + config, forward_args = create_build_config() + + # Override build config from environment variables + override_build_config_from_env(config) + + if "clean" in sys.argv: + # only disables BUILD_SETUP, but keep the argument for setuptools + config.build_setup = False + + if config.cpp_standard < 20: + raise ValueError("nvfuser requires C++20 standard or higher") + + sys.argv = [sys.argv[0]] + forward_args + + run(config, version_tag(config), relative_path="..") + + +if __name__ == "__main__": + main() diff --git a/tests/cpp/python_frontend/test_nvfuser_fusion_cache.cpp b/python/tests/python_frontend/test_nvfuser_fusion_cache.cpp similarity index 100% rename from tests/cpp/python_frontend/test_nvfuser_fusion_cache.cpp rename to python/tests/python_frontend/test_nvfuser_fusion_cache.cpp diff --git a/tests/cpp/python_frontend/test_nvfuser_fusion_definition.cpp b/python/tests/python_frontend/test_nvfuser_fusion_definition.cpp similarity index 100% rename from tests/cpp/python_frontend/test_nvfuser_fusion_definition.cpp rename to python/tests/python_frontend/test_nvfuser_fusion_definition.cpp diff --git a/tests/cpp/python_frontend/test_nvfuser_fusion_record.cpp b/python/tests/python_frontend/test_nvfuser_fusion_record.cpp similarity index 100% rename from tests/cpp/python_frontend/test_nvfuser_fusion_record.cpp rename to python/tests/python_frontend/test_nvfuser_fusion_record.cpp diff --git a/python/tools/__init__.py b/python/tools/__init__.py new file mode 100644 index 00000000000..51ba303bccb --- /dev/null +++ b/python/tools/__init__.py @@ -0,0 +1,3 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause diff --git a/python/tools/gen_nvfuser_version.py b/python/tools/gen_nvfuser_version.py new file mode 100644 index 00000000000..a09eda53539 --- /dev/null +++ b/python/tools/gen_nvfuser_version.py @@ -0,0 +1,75 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +import subprocess +import sys +from pathlib import Path + +UNKNOWN = "Unknown" +nvfuser_root = Path(__file__).parent.parent + + +# note that this root currently is still part of pytorch. +def get_sha() -> str: + try: + return ( + subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=nvfuser_root) + .decode("ascii") + .strip() + ) + except Exception: + import os + + # assume the $NVFUSER_VERSION is in sha form + if nvfuser_version := os.environ.get("NVFUSER_VERSION"): + assert ( + len(nvfuser_version) < 11 + ), "The NVFUSER_VERSION should be in sha form" + return nvfuser_version + return UNKNOWN + + +def get_version() -> str: + sha = get_sha() + version = ( + open((nvfuser_root / "version.txt"), "r").read().strip() + "+git" + sha[:7] + ) + return version + + +def get_pytorch_cmake_prefix(): + from subprocess import Popen, PIPE + + # need to do this in a separate process so we are not going to delete nvfuser library while it's loaded by torch + process_torch_prefix = Popen( + [ + sys.executable, + "-c", + "import torch.utils; print(torch.utils.cmake_prefix_path)", + ], + stdout=PIPE, + ) + stdout_msg, error_msg = process_torch_prefix.communicate() + return stdout_msg.decode("utf-8").rstrip("\n") + + +def get_pytorch_use_distributed(): + from subprocess import Popen, PIPE + + # need to do this in a separate process so we are not going to delete nvfuser library while it's loaded by torch + process_torch_prefix = Popen( + [ + sys.executable, + "-c", + "import torch; print(torch._C._has_distributed())", + ], + stdout=PIPE, + ) + stdout_msg, error_msg = process_torch_prefix.communicate() + return stdout_msg.decode("utf-8").rstrip("\n") + + +if __name__ == "__main__": + version_file = nvfuser_root / "nvfuser" / "version.py" + with open(version_file, "w") as f: + f.write("_version_str = '{}'\n".format(get_version())) diff --git a/python/tools/memory.py b/python/tools/memory.py new file mode 100644 index 00000000000..1ed95f8ded5 --- /dev/null +++ b/python/tools/memory.py @@ -0,0 +1,28 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + + +def get_available_memory_gb(): + """Returns the available memory in GB.""" + try: + import psutil + + return psutil.virtual_memory().available / 1024 / 1024 / 1024 + except: # noqa: E722 + pass + + try: + with open("/proc/meminfo", "r") as f: + while True: + line = f.readline() + if line.startswith("MemAvailable:"): + mem = line.split()[1] + assert line.split()[2] == "kB" + return int(mem) / 1024 / 1024 + if not line: + break + except: # noqa: E722 + pass + + return 0 diff --git a/python/utils.py b/python/utils.py new file mode 100644 index 00000000000..3c13b898f6d --- /dev/null +++ b/python/utils.py @@ -0,0 +1,567 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import argparse +import os +import multiprocessing +import subprocess +import sys +import shutil +from dataclasses import dataclass +import setuptools.command.build_ext + + +@dataclass +class BuildConfig: + cmake_only: bool = False + build_setup: bool = True + no_python: bool = False + no_test: bool = False + no_benchmark: bool = False + no_ninja: bool = False + build_with_ucc: bool = False + build_with_asan: bool = False + build_without_distributed: bool = False + build_with_system_nvtx: bool = True + explicit_error_check: bool = False + overwrite_version: bool = False + version_tag: str = None + build_type: str = "Release" + wheel_name: str = "nvfuser" + build_dir: str = "" + install_dir: str = "" + install_requires: list = None + extras_require: dict = None + cpp_standard: int = 20 + + def __post_init__(self): + # dataclass cannot have mutable default values in the class definition + if self.install_requires is None: + self.install_requires = [] + if self.extras_require is None: + self.extras_require = {} + + +def check_env_flag_bool_default(name: str, default: str = "") -> bool: + if name not in os.environ: + return default + return os.getenv(name).upper() in ["ON", "1", "YES", "TRUE", "Y"] + + +def get_env_flag_bool(name: str) -> bool: + assert name in os.environ + return os.getenv(name).upper() in ["ON", "1", "YES", "TRUE", "Y"] + + +def parse_args(): + parser = argparse.ArgumentParser( + description="NVFUSER build options", add_help=False + ) + + # Add arguments that don't go to setuptools + parser.add_argument( + "--cmake-only", + dest="cmake_only", + action="store_true", + help="Only generate ./build directory with cmake setup", + ) + parser.add_argument( + "--no-python", + dest="no_python", + action="store_true", + help="Skips python API target libnvfuser.so", + ) + parser.add_argument( + "--no-test", + dest="no_test", + action="store_true", + help="Skips cpp tests test_nvfuser", + ) + parser.add_argument( + "--no-benchmark", + dest="no_benchmark", + action="store_true", + help="Skips benchmark target nvfuser_bench", + ) + parser.add_argument( + "--no-ninja", + dest="no_ninja", + action="store_true", + help="Use make instead of ninja for build", + ) + parser.add_argument( + "--build-with-ucc", + dest="build_with_ucc", + action="store_true", + help="Build nvfuser with UCC support", + ) + parser.add_argument( + "--explicit-error-check", + dest="explicit_error_check", + action="store_true", + help="Enable explicit error checking", + ) + parser.add_argument( + "--build-with-asan", + dest="build_with_asan", + action="store_true", + help="Build with Address Sanitizer", + ) + parser.add_argument( + "--build-without-distributed", + dest="build_without_distributed", + action="store_true", + help="Build nvfuser without multidevice support", + ) + parser.add_argument( + "--no-system-nvtx", + dest="no_system_nvtx", + action="store_true", + help="Disable system NVTX", + ) + parser.add_argument( + "--debug", + dest="debug_mode", + action="store_true", + help="Building nvfuser in debug mode", + ) + parser.add_argument( + "--debinfo", + dest="debinfo_mode", + action="store_true", + help="Building nvfuser in release mode with debug info", + ) + parser.add_argument( + "--build-dir", + dest="build_dir", + type=str, + default="", + help="Specify in which directory to build nvfuser", + ) + parser.add_argument( + "--install-dir", + dest="install_dir", + type=str, + default="", + help="Specify in which directory to install nvfuser", + ) + parser.add_argument( + "-install_requires", + dest="install_requires", + type=str, + help="Specify package required for installation", + ) + parser.add_argument( + "--extras_require", + dest="extras_require", + type=str, + help="Specify extra requirements", + ) + parser.add_argument( + "-version-tag", + dest="version_tag", + type=str, + help="Specify the tag for build nvfuser version", + ) + parser.add_argument( + "-wheel-name", + dest="wheel_name", + type=str, + default="nvfuser", + help="Specify the wheel name", + ) + parser.add_argument( + "--cpp", + dest="cpp_standard", + type=int, + help="Specify the C++ standard to use", + default=20, + ) + + # Use parse_known_args to separate our arguments from setuptools arguments + args, forward_args = parser.parse_known_args() + return args, forward_args + + +# Create BuildConfig using argparse +def create_build_config(): + # Parse arguments and set global variables accordingly + args, forward_args = parse_args() + + # Create a BuildConfig from args + config = BuildConfig( + cmake_only=args.cmake_only, + no_python=args.no_python, + no_test=args.no_test, + no_benchmark=args.no_benchmark, + no_ninja=args.no_ninja, + build_with_ucc=args.build_with_ucc, + build_with_asan=args.build_with_asan, + build_without_distributed=args.build_without_distributed, + build_with_system_nvtx=not args.no_system_nvtx, + explicit_error_check=args.explicit_error_check, + wheel_name=args.wheel_name, + build_dir=args.build_dir, + install_dir=args.install_dir, + cpp_standard=args.cpp_standard, + ) + + # Apply remaining options + if args.debug_mode: + config.build_type = "Debug" + if args.debinfo_mode: + config.build_type = "RelwithDebInfo" + if args.install_requires: + config.install_requires = args.install_requires.split(",") + if args.extras_require: + config.extras_require = eval(args.extras_require) + if args.version_tag: + config.version_tag = args.version_tag + config.overwrite_version = True + return config, forward_args + + +# Override BuildConfig with environment variables. Only change if variable +# exists. Do not use default to override argparse. +def override_build_config_from_env(config): + # Command line arguments don't work on PEP517 builds and will be silently ignored, + # so we need to pass those options as environment variables instead. + if "NVFUSER_BUILD_CMAKE_ONLY" in os.environ: + config.cmake_only = get_env_flag_bool("NVFUSER_BUILD_CMAKE_ONLY") + if "NVFUSER_BUILD_SETUP" in os.environ: + config.build_setup = get_env_flag_bool("NVFUSER_BUILD_SETUP") + if "NVFUSER_BUILD_NO_PYTHON" in os.environ: + config.no_python = get_env_flag_bool("NVFUSER_BUILD_NO_PYTHON") + if "NVFUSER_BUILD_NO_TEST" in os.environ: + config.no_test = get_env_flag_bool("NVFUSER_BUILD_NO_TEST") + if "NVFUSER_BUILD_NO_BENCHMARK" in os.environ: + config.no_benchmark = get_env_flag_bool("NVFUSER_BUILD_NO_BENCHMARK") + if "NVFUSER_BUILD_NO_NINJA" in os.environ: + config.no_ninja = get_env_flag_bool("NVFUSER_BUILD_NO_NINJA") + if "NVFUSER_BUILD_WITH_UCC" in os.environ: + config.build_with_ucc = get_env_flag_bool("NVFUSER_BUILD_WITH_UCC") + if "NVFUSER_BUILD_WITH_ASAN" in os.environ: + config.build_with_asan = get_env_flag_bool("NVFUSER_BUILD_WITH_ASAN") + if "NVFUSER_BUILD_WITHOUT_DISTRIBUTED" in os.environ: + config.build_without_distributed = get_env_flag_bool( + "NVFUSER_BUILD_WITHOUT_DISTRIBUTED" + ) + if "NVFUSER_BUILD_WITH_SYSTEM_NVTX" in os.environ: + config.build_with_system_nvtx = get_env_flag_bool( + "NVFUSER_BUILD_WITH_SYSTEM_NVTX" + ) + if "NVFUSER_BUILD_EXPLICIT_ERROR_CHECK" in os.environ: + config.explicit_error_check = get_env_flag_bool( + "NVFUSER_BUILD_EXPLICIT_ERROR_CHECK" + ) + if "NVFUSER_BUILD_OVERWRITE_VERSION" in os.environ: + config.overwrite_version = get_env_flag_bool("NVFUSER_BUILD_OVERWRITE_VERSION") + if "NVFUSER_BUILD_VERSION_TAG" in os.environ: + config.version_tag = os.getenv("NVFUSER_BUILD_VERSION_TAG") + if "NVFUSER_BUILD_BUILD_TYPE" in os.environ: + config.build_type = os.getenv("NVFUSER_BUILD_BUILD_TYPE") + if "NVFUSER_BUILD_WHEEL_NAME" in os.environ: + config.wheel_name = os.getenv("NVFUSER_BUILD_WHEEL_NAME") + if "NVFUSER_BUILD_DIR" in os.environ: + config.build_dir = os.getenv("NVFUSER_BUILD_DIR") + if "NVFUSER_BUILD_INSTALL_DIR" in os.environ: + config.install_dir = os.getenv("NVFUSER_BUILD_INSTALL_DIR") + if "NVFUSER_BUILD_INSTALL_REQUIRES" in os.environ: + config.install_requires = os.getenv("NVFUSER_BUILD_INSTALL_REQUIRES").split(",") + if "NVFUSER_BUILD_EXTRAS_REQUIRE" in os.environ: + config.extras_require = eval(os.getenv("NVFUSER_BUILD_EXTRAS_REQUIRE")) + if "NVFUSER_BUILD_CPP_STANDARD" in os.environ: + config.cpp_standard = int(os.getenv("NVFUSER_BUILD_CPP_STANDARD")) + if "NVFUSER_BUILD_VERSION_TAG" in os.environ: + config.overwrite_version = True + config.version_tag = os.getenv("NVFUSER_BUILD_VERSION_TAG") + + +class build_ext(setuptools.command.build_ext.build_ext): + def build_extension(self, ext): + if ext.name == "nvfuser._C": + # Copy files on necessity. + filename = self.get_ext_filename(self.get_ext_fullname(ext.name)) + fileext = os.path.splitext(filename)[1] + + libnvfuser_path = os.path.join("./nvfuser/lib", f"libnvfuser{fileext}") + assert os.path.exists(libnvfuser_path) + install_dst = os.path.join(self.build_lib, filename) + if not os.path.exists(os.path.dirname(install_dst)): + os.makedirs(os.path.dirname(install_dst)) + self.copy_file(libnvfuser_path, install_dst) + else: + super().build_extension(ext) + + +class concat_third_party_license: + def __init__(self, directory="third_party"): + self.license_file = "LICENSE" + self.directory = directory + + def __enter__(self): + # read original license file + with open(self.license_file, "r") as f: + self.nvfuser_license_txt = f.read() + + licenses = {"LICENSE", "LICENSE.txt", "LICENSE.rst", "COPYING.BSD"} + + # aggregated license, we key on project name + aggregated_license = {} + for root, dirs, files in os.walk(self.directory): + license = list(licenses & set(files)) + if license: + project_name = root.split("/")[-1] + # let's worry about multiple license when we see it. + assert len(license) == 1 + license_entry = os.path.join(root, license[0]) + if project_name in aggregated_license: + # Only add it if the license is different + aggregated_license[project_name].append(license_entry) + else: + aggregated_license[project_name] = [license_entry] + return aggregated_license + + def __exit__(self, exception_type, exception_value, traceback): + # restore original license file + with open(self.license_file, "w") as f: + f.write(self.nvfuser_license_txt) + + +try: + from wheel.bdist_wheel import bdist_wheel +except ImportError: + build_whl = None +else: + + class build_whl(bdist_wheel): + def run(self): + with concat_third_party_license() as tp_licenses: + if len(tp_licenses) != 0: + with open("LICENSE", "a") as f: + f.write("\n\n") + f.write( + "NVIDIA/fuser depends on libraries with license listed below:" + ) + + for project_name, license_files in tp_licenses.items(): + # check all license files are identical + with open(license_files[0], "r") as f: + license_ref = f.read() + + def check_file(file_name): + with open(file_name, "r") as f: + return f.read() == license_ref + + identical_flag = all(map(check_file, license_files[1:])) + if not identical_flag: + raise RuntimeError( + "inconsistent license found for project: ", + project_name, + " check its license files under: ", + license_files, + ) + + with open("LICENSE", "a") as f: + f.write("\n\nProject Name: " + project_name) + f.write("\nLicense Files:\n") + for file_name in license_files: + f.write("\t" + file_name) + f.write("\n" + license_ref) + + # generate whl before we restore LICENSE + super().run() + + +def get_cmake_bin(): + # TODO: double check cmake version here and retrieve later version if necessary + return "cmake" + + +def cmake(config, relative_path): + from tools.memory import get_available_memory_gb + + # make build directories + cwd = os.path.dirname(os.path.abspath(__file__)) + cmake_build_dir = ( + os.path.join(cwd, "build") if not config.build_dir else config.build_dir + ) + if not os.path.exists(cmake_build_dir): + os.makedirs(cmake_build_dir) + + install_prefix = ( + os.path.join(cwd, "nvfuser") if not config.install_dir else config.install_dir + ) + + from tools.gen_nvfuser_version import ( + get_pytorch_cmake_prefix, + get_pytorch_use_distributed, + ) + + # this is used to suppress import error. + # so we can get the right pytorch prefix for cmake + import logging + + logger = logging.getLogger("nvfuser") + logger_level = logger.getEffectiveLevel() + logger.setLevel(logging.CRITICAL) + + pytorch_cmake_config = "-DCMAKE_PREFIX_PATH=" + get_pytorch_cmake_prefix() + + logger.setLevel(logger_level) + + pytorch_use_distributed = get_pytorch_use_distributed() + + # generate cmake directory + cmd_str = [ + get_cmake_bin(), + pytorch_cmake_config, + "-DCMAKE_BUILD_TYPE=" + config.build_type, + f"-DCMAKE_INSTALL_PREFIX={install_prefix}", + f"-DNVFUSER_CPP_STANDARD={config.cpp_standard}", + f"-DUSE_DISTRIBUTED={pytorch_use_distributed}", + "-B", + cmake_build_dir, + ] + if config.build_with_ucc: + cmd_str.append("-DNVFUSER_STANDALONE_BUILD_WITH_UCC=ON") + if config.explicit_error_check: + cmd_str.append("-DNVFUSER_EXPLICIT_ERROR_CHECK=ON") + if not config.no_ninja: + cmd_str.append("-G") + cmd_str.append("Ninja") + if not config.no_test: + cmd_str.append("-DBUILD_TEST=ON") + if not config.no_python: + cmd_str.append("-DBUILD_PYTHON=ON") + cmd_str.append(f"-DPython_EXECUTABLE={sys.executable}") + if not config.no_benchmark: + cmd_str.append("-DBUILD_NVFUSER_BENCHMARK=ON") + if config.build_with_asan: + cmd_str.append("-DNVFUSER_BUILD_WITH_ASAN=ON") + if config.build_without_distributed: + cmd_str.append("-DNVFUSER_DISTRIBUTED=OFF") + if config.build_with_system_nvtx: + cmd_str.append("-DUSE_SYSTEM_NVTX=ON") + cmd_str.append(relative_path) + + print(f"Configuring CMake with {' '.join(cmd_str)}") + subprocess.check_call(cmd_str) + + max_jobs = multiprocessing.cpu_count() + mem_gb_per_task = 3 # Currently compilation of nvFuser souce code takes ~3GB of memory per task, we should adjust this value if it changes in the future. + available_mem = get_available_memory_gb() + if available_mem > 0: + max_jobs_mem = int(available_mem / mem_gb_per_task) + max_jobs = min(max_jobs, max_jobs_mem) + + if not config.cmake_only: + # build binary + max_jobs = os.getenv("MAX_JOBS", str(max_jobs)) + print(f"Using {max_jobs} jobs for compilation") + cmd_str = [ + get_cmake_bin(), + "--build", + cmake_build_dir, + "--target", + "install", + "--", + "-j", + max_jobs, + ] + subprocess.check_call(cmd_str) + + +def create_clean(relative_path): + class clean(setuptools.Command): + user_options = [] + + def initialize_options(self): + pass + + def finalize_options(self): + pass + + def run(self): + import glob + + gitignore_path = os.path.join(relative_path, ".gitignore") + assert os.path.exists(gitignore_path) + with open(gitignore_path, "r") as f: + ignores = f.read() + for entry in ignores.split("\n"): + # ignore comment in .gitignore + if len(entry) >= 1 and entry[0] != "#": + for filename in glob.glob(entry): + print("removing: ", filename) + try: + os.remove(filename) + except OSError: + shutil.rmtree(filename, ignore_errors=True) + + return clean + + +def run(config, version_tag, relative_path): + from setuptools import Extension, setup, find_packages + + # NOTE(crcrpar): Deliberately build basically two dynamic libraries here so that they can + # be treated as "nvfuser_package_data". This function call will put the two of "nvfuser" and + # "nvfuser_codegen" into "./nvfuser/lib", and the former will be "nvfuser._C". + if config.build_setup: + cmake(config, relative_path) + if not config.cmake_only: + # NOTE: package include files for cmake + # TODO(crcrpar): Better avoid hardcoding `libnvfuser_codegen.so` + # might can be treated by using `exclude_package_data`. + nvfuser_package_data = [ + "lib/libnvfuser_codegen.so", + "include/nvfuser/*.h", + "include/nvfuser/struct.inl", + "include/nvfuser/C++20/type_traits", + "include/nvfuser/device_lower/*.h", + "include/nvfuser/device_lower/analysis/*.h", + "include/nvfuser/device_lower/pass/*.h", + "include/nvfuser/dynamic_type/*", + "include/nvfuser/dynamic_type/C++20/*", + "include/nvfuser/kernel_db/*.h", + "include/nvfuser/multidevice/*.h", + "include/nvfuser/ops/*.h", + "include/nvfuser/ir/*.h", + "include/nvfuser/python_frontend/*.h", + "include/nvfuser/scheduler/*.h", + "include/nvfuser/serde/*.h", + "include/nvfuser/flatbuffers/*.h", + "include/nvfuser/host_ir/*.h", + "include/nvfuser/id_model/*.h", + "share/cmake/nvfuser/NvfuserConfig*", + # TODO(crcrpar): it'd be better to ship the following two binaries. + # Would need some change in CMakeLists.txt. + # "bin/test_nvfuser", + # "bin/nvfuser_bench" + ] + + setup( + name=config.wheel_name, + version=version_tag, + url="https://github.com/NVIDIA/Fuser", + description="A Fusion Code Generator for NVIDIA GPUs (commonly known as 'nvFuser')", + packages=find_packages(), + ext_modules=[Extension(name="nvfuser._C", sources=[])], + license_files=("LICENSE",), + cmdclass={ + "bdist_wheel": build_whl, + "build_ext": build_ext, + "clean": create_clean(relative_path), + }, + package_data={ + "nvfuser": nvfuser_package_data, + }, + install_requires=config.install_requires, + extras_require={ + "test": ["numpy", "expecttest", "pytest"], + **config.extras_require, + }, + license="BSD-3-Clause", + ) diff --git a/version.txt b/python/version.txt similarity index 100% rename from version.txt rename to python/version.txt diff --git a/setup.py b/setup.py index 4aced7e1a57..e1cd19eb726 100644 --- a/setup.py +++ b/setup.py @@ -1,12 +1,16 @@ # SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause +# +# Usage: +# [MAX_JOBS] python setup.py develop [args] +# # Environment variables used during build: # -# MAX_JOBS +# MAX_JOBS # maximum number of compile jobs we should use to compile your code # -# build argument: +# NvFuser build arguments: # # --cmake-only # Only generate ./build directory with cmake setup @@ -39,7 +43,7 @@ # Specify in which directory to build nvfuser. If not specified, the default build directory is "./build". # # --install-dir= -# Specify in which directory to install nvfuser. If not specified, the default install directory is "./nvfuser". +# Specify in which directory to install nvfuser. If not specified, the default install directory is "./python/nvfuser". # # -version-tag=TAG # Specify the tag for build nvfuser version, this is used for pip wheel @@ -58,395 +62,45 @@ # Specify the C++ standard to use for building nvfuser. The default is C++20. # -import multiprocessing -import os -import shutil -import subprocess -import sys - -import setuptools -import setuptools.command.build_ext -from setuptools import Extension, setup, find_packages - -# pick args used by this script -CMAKE_ONLY = False -BUILD_SETUP = True -NO_PYTHON = False -NO_TEST = False -NO_BENCHMARK = False -NO_NINJA = False -BUILD_WITH_UCC = False -BUILD_WITH_ASAN = False -BUILD_WITHOUT_DISTRIBUTED = False -BUILD_WITH_SYSTEM_NVTX = True -OVERWRITE_VERSION = False -EXPLICIT_ERROR_CHECK = False -VERSION_TAG = None -BUILD_TYPE = "Release" -WHEEL_NAME = "nvfuser" -BUILD_DIR = "" -INSTALL_DIR = "" -INSTALL_REQUIRES = [] -EXTRAS_REQUIRE = {} -CPP_STANDARD = 20 -forward_args = [] -for i, arg in enumerate(sys.argv): - if arg == "--cmake-only": - CMAKE_ONLY = True - continue - if arg == "--no-python": - NO_PYTHON = True - continue - if arg == "--no-test": - NO_TEST = True - continue - if arg == "--no-benchmark": - NO_BENCHMARK = True - continue - if arg == "--no-ninja": - NO_NINJA = True - continue - if arg == "--build-with-ucc": - BUILD_WITH_UCC = True - continue - if arg == "--explicit-error-check": - EXPLICIT_ERROR_CHECK = True - continue - if arg == "--build-with-asan": - BUILD_WITH_ASAN = True - continue - if arg == "--build-without-distributed": - BUILD_WITHOUT_DISTRIBUTED = True - continue - if arg == "--no-system-nvtx": - BUILD_WITH_SYSTEM_NVTX = False - continue - if arg == "--debug": - BUILD_TYPE = "Debug" - continue - if arg == "--debinfo": - BUILD_TYPE = "RelwithDebInfo" - continue - if arg.startswith("--build-dir"): - BUILD_DIR = arg.split("=")[1] - continue - if arg.startswith("--install-dir"): - INSTALL_DIR = arg.split("=")[1] - continue - if arg.startswith("-install_requires="): - INSTALL_REQUIRES = arg.split("=")[1].split(",") - continue - if arg.startswith("--extras_require="): - EXTRAS_REQUIRE = eval("=".join(arg.split("=")[1:])) - continue - if arg.startswith("-version-tag="): - OVERWRITE_VERSION = True - VERSION_TAG = arg.split("=")[1] - continue - if arg.startswith("-wheel-name="): - WHEEL_NAME = arg.split("=")[1] - continue - if arg.startswith("--cpp="): - CPP_STANDARD = int(arg.split("=")[1]) - if CPP_STANDARD < 20: - raise ValueError("nvfuser requires C++20 standard or higher") - continue - if arg in ["clean"]: - # only disables BUILD_SETUP, but keep the argument for setuptools - BUILD_SETUP = False - forward_args.append(arg) -sys.argv = forward_args - - -def get_cmake_bin(): - # TODO: double check cmake version here and retrieve later version if necessary - return "cmake" - - -class clean(setuptools.Command): - user_options = [] - - def initialize_options(self): - pass - - def finalize_options(self): - pass - - def run(self): - import glob - - with open(".gitignore", "r") as f: - ignores = f.read() - for entry in ignores.split("\n"): - # ignore comment in .gitignore - if len(entry) >= 1 and entry[0] != "#": - for filename in glob.glob(entry): - print("removing: ", filename) - try: - os.remove(filename) - except OSError: - shutil.rmtree(filename, ignore_errors=True) - - -class build_ext(setuptools.command.build_ext.build_ext): - def build_extension(self, ext): - if ext.name == "nvfuser._C": - # Copy files on necessity. - filename = self.get_ext_filename(self.get_ext_fullname(ext.name)) - fileext = os.path.splitext(filename)[1] - - libnvfuser_path = os.path.join("./nvfuser/lib", f"libnvfuser{fileext}") - assert os.path.exists(libnvfuser_path) - install_dst = os.path.join(self.build_lib, filename) - if not os.path.exists(os.path.dirname(install_dst)): - os.makedirs(os.path.dirname(install_dst)) - self.copy_file(libnvfuser_path, install_dst) - else: - super().build_extension(ext) - - -class concat_third_party_license: - def __init__(self, directory="third_party"): - self.license_file = "LICENSE" - self.directory = directory - - def __enter__(self): - # read original license file - with open(self.license_file, "r") as f: - self.nvfuser_license_txt = f.read() - - licenses = {"LICENSE", "LICENSE.txt", "LICENSE.rst", "COPYING.BSD"} - - # aggregated license, we key on project name - aggregated_license = {} - for root, dirs, files in os.walk(self.directory): - license = list(licenses & set(files)) - if license: - project_name = root.split("/")[-1] - # let's worry about multiple license when we see it. - assert len(license) == 1 - license_entry = os.path.join(root, license[0]) - if project_name in aggregated_license: - # Only add it if the license is different - aggregated_license[project_name].append(license_entry) - else: - aggregated_license[project_name] = [license_entry] - return aggregated_license +# TODO Remove nvfuser symbolic link to python/nvfuser +# TODO Remove tools/gen_nvfuser_version.py symbolic link to python/tools/gen_nvfuser_version.py +# TODO Remove tools/memory.py symbolic link to python/tools/memory.py - def __exit__(self, exception_type, exception_value, traceback): - # restore original license file - with open(self.license_file, "w") as f: - f.write(self.nvfuser_license_txt) - - -try: - from wheel.bdist_wheel import bdist_wheel -except ImportError: - build_whl = None -else: - - class build_whl(bdist_wheel): - def run(self): - with concat_third_party_license() as tp_licenses: - if len(tp_licenses) != 0: - with open("LICENSE", "a") as f: - f.write("\n\n") - f.write( - "NVIDIA/fuser depends on libraries with license listed below:" - ) - - for project_name, license_files in tp_licenses.items(): - # check all license files are identical - with open(license_files[0], "r") as f: - license_ref = f.read() - - def check_file(file_name): - with open(file_name, "r") as f: - return f.read() == license_ref - - identical_flag = all(map(check_file, license_files[1:])) - if not identical_flag: - raise RuntimeError( - "inconsistent license found for project: ", - project_name, - " check its license files under: ", - license_files, - ) +import sys - with open("LICENSE", "a") as f: - f.write("\n\nProject Name: " + project_name) - f.write("\nLicense Files:\n") - for file_name in license_files: - f.write("\t" + file_name) - f.write("\n" + license_ref) - # generate whl before we restore LICENSE - super().run() +from python.utils import ( + run, + create_build_config, +) -def version_tag(): - from tools.gen_nvfuser_version import get_version +def version_tag(config): + from python.tools.gen_nvfuser_version import get_version version = get_version() - if OVERWRITE_VERSION: + if config.overwrite_version: version = version.split("+")[0] - if len(VERSION_TAG) != 0: + if len(config.version_tag) != 0: # use "." to be pypi friendly - version = ".".join([version, VERSION_TAG]) + version = ".".join([version, config.version_tag]) return version -from tools.memory import get_available_memory_gb - - -def cmake(): - # make build directories - cwd = os.path.dirname(os.path.abspath(__file__)) - cmake_build_dir = os.path.join(cwd, "build") if not BUILD_DIR else BUILD_DIR - if not os.path.exists(cmake_build_dir): - os.makedirs(cmake_build_dir) - - install_prefix = os.path.join(cwd, "nvfuser") if not INSTALL_DIR else INSTALL_DIR - - from tools.gen_nvfuser_version import ( - get_pytorch_cmake_prefix, - get_pytorch_use_distributed, - ) - - # this is used to suppress import error. - # so we can get the right pytorch prefix for cmake - import logging - - logger = logging.getLogger("nvfuser") - logger_level = logger.getEffectiveLevel() - logger.setLevel(logging.CRITICAL) - - pytorch_cmake_config = "-DCMAKE_PREFIX_PATH=" + get_pytorch_cmake_prefix() - - logger.setLevel(logger_level) - - pytorch_use_distributed = get_pytorch_use_distributed() - - # generate cmake directory - cmd_str = [ - get_cmake_bin(), - pytorch_cmake_config, - "-DCMAKE_BUILD_TYPE=" + BUILD_TYPE, - f"-DCMAKE_INSTALL_PREFIX={install_prefix}", - f"-DNVFUSER_CPP_STANDARD={CPP_STANDARD}", - f"-DUSE_DISTRIBUTED={pytorch_use_distributed}", - "-B", - cmake_build_dir, - ] - if BUILD_WITH_UCC: - cmd_str.append("-DNVFUSER_STANDALONE_BUILD_WITH_UCC=ON") - if EXPLICIT_ERROR_CHECK: - cmd_str.append("-DNVFUSER_EXPLICIT_ERROR_CHECK=ON") - if not NO_NINJA: - cmd_str.append("-G") - cmd_str.append("Ninja") - if not NO_TEST: - cmd_str.append("-DBUILD_TEST=ON") - if not NO_PYTHON: - cmd_str.append("-DBUILD_PYTHON=ON") - cmd_str.append(f"-DPython_EXECUTABLE={sys.executable}") - if not NO_BENCHMARK: - cmd_str.append("-DBUILD_NVFUSER_BENCHMARK=ON") - if BUILD_WITH_ASAN: - cmd_str.append("-DNVFUSER_BUILD_WITH_ASAN=ON") - if BUILD_WITHOUT_DISTRIBUTED: - cmd_str.append("-DNVFUSER_DISTRIBUTED=OFF") - if BUILD_WITH_SYSTEM_NVTX: - cmd_str.append("-DUSE_SYSTEM_NVTX=ON") - cmd_str.append(".") - - print(f"Configuring CMake with {' '.join(cmd_str)}") - subprocess.check_call(cmd_str) - - max_jobs = multiprocessing.cpu_count() - mem_gb_per_task = 3 # Currently compilation of nvFuser souce code takes ~3GB of memory per task, we should adjust this value if it changes in the future. - available_mem = get_available_memory_gb() - if available_mem > 0: - max_jobs_mem = int(available_mem / mem_gb_per_task) - max_jobs = min(max_jobs, max_jobs_mem) +def main(): + # Parse arguments using argparse + config, forward_args = create_build_config() - if not CMAKE_ONLY: - # build binary - max_jobs = os.getenv("MAX_JOBS", str(max_jobs)) - print(f"Using {max_jobs} jobs for compilation") - cmd_str = [ - get_cmake_bin(), - "--build", - cmake_build_dir, - "--target", - "install", - "--", - "-j", - max_jobs, - ] - subprocess.check_call(cmd_str) + if "clean" in sys.argv: + # only disables BUILD_SETUP, but keep the argument for setuptools + config.build_setup = False + if config.cpp_standard < 20: + raise ValueError("nvfuser requires C++20 standard or higher") -def main(): - # NOTE(crcrpar): Deliberately build basically two dynamic libraries here so that they can - # be treated as "nvfuser_package_data". This function call will put the two of "nvfuser" and - # "nvfuser_codegen" into "./nvfuser/lib", and the former will be "nvfuser._C". - if BUILD_SETUP: - cmake() - if not CMAKE_ONLY: - # NOTE: package include files for cmake - # TODO(crcrpar): Better avoid hardcoding `libnvfuser_codegen.so` - # might can be treated by using `exclude_package_data`. - nvfuser_package_data = [ - "lib/libnvfuser_codegen.so", - "include/nvfuser/*.h", - "include/nvfuser/struct.inl", - "include/nvfuser/C++20/type_traits", - "include/nvfuser/device_lower/*.h", - "include/nvfuser/device_lower/analysis/*.h", - "include/nvfuser/device_lower/pass/*.h", - "include/nvfuser/dynamic_type/*", - "include/nvfuser/dynamic_type/C++20/*", - "include/nvfuser/kernel_db/*.h", - "include/nvfuser/multidevice/*.h", - "include/nvfuser/ops/*.h", - "include/nvfuser/ir/*.h", - "include/nvfuser/python_frontend/*.h", - "include/nvfuser/scheduler/*.h", - "include/nvfuser/serde/*.h", - "include/nvfuser/flatbuffers/*.h", - "include/nvfuser/host_ir/*.h", - "include/nvfuser/id_model/*.h", - "share/cmake/nvfuser/NvfuserConfig*", - # TODO(crcrpar): it'd be better to ship the following two binaries. - # Would need some change in CMakeLists.txt. - # "bin/test_nvfuser", - # "bin/nvfuser_bench" - ] + sys.argv = [sys.argv[0]] + forward_args - setup( - name=WHEEL_NAME, - version=version_tag(), - url="https://github.com/NVIDIA/Fuser", - description="A Fusion Code Generator for NVIDIA GPUs (commonly known as 'nvFuser')", - packages=find_packages(), - ext_modules=[Extension(name="nvfuser._C", sources=[])], - license_files=("LICENSE",), - cmdclass={ - "bdist_wheel": build_whl, - "build_ext": build_ext, - "clean": clean, - }, - package_data={ - "nvfuser": nvfuser_package_data, - }, - install_requires=INSTALL_REQUIRES, - extras_require={ - "test": ["numpy", "expecttest", "pytest"], - **EXTRAS_REQUIRE, - }, - license="BSD-3-Clause", - ) + run(config, version_tag(config), relative_path=".") if __name__ == "__main__": diff --git a/tests/cpp/test_alias_analysis.cpp b/tests/cpp/test_alias_analysis.cpp index 937167a5517..260172fb4c2 100644 --- a/tests/cpp/test_alias_analysis.cpp +++ b/tests/cpp/test_alias_analysis.cpp @@ -270,7 +270,9 @@ TEST_F(AliasAnalysisTest, BroadcastExpandDimensions) { EXPECT_EQ(analysis.getRoot(expanded_tv), in); } -TEST_F(AliasAnalysisTest, NoAliasForReshardingExprs) { +// See PR: https://github.com/NVIDIA/Fuser/pull/4274 +// for alias analysis for resharding exprs +TEST_F(AliasAnalysisTest, AliasForReshardingExprs) { Fusion fusion; FusionGuard fg(&fusion); @@ -288,7 +290,7 @@ TEST_F(AliasAnalysisTest, NoAliasForReshardingExprs) { fusion.addOutput(out); AliasAnalysisResult analysis = findAliases(&fusion); - EXPECT_TRUE(analysis.getRoot(out) == nullptr); + EXPECT_TRUE(analysis.getRoot(out) == in); } } // namespace nvfuser diff --git a/tests/cpp/test_combined_inner_outer_reduction.cpp b/tests/cpp/test_combined_inner_outer_reduction.cpp index 18a6c439930..88825ae9237 100644 --- a/tests/cpp/test_combined_inner_outer_reduction.cpp +++ b/tests/cpp/test_combined_inner_outer_reduction.cpp @@ -1158,13 +1158,15 @@ TEST_P(TmaWarpSpecializedTest, RMSNormBwd) { __LINE__, __FILE__); } +// batch size is revised to 132*148 which is divisible by sm count on H100 & +// B200 will change back to 32 & 2048 after predicate for 1D TMA is added. INSTANTIATE_TEST_SUITE_P( , TmaWarpSpecializedTest, ::testing::Combine( - testing::Values(true, false), + testing::Values(false), // tmp disable tma warp specialized testing::Values(DataType::Float, DataType::BFloat16), - testing::Values(32, 2048), + testing::Values(132 * 148), ::testing::Range((int64_t)1024, (int64_t)8193, (int64_t)1024)), [](const testing::TestParamInfo& info) -> std::string { diff --git a/tests/cpp/test_expr_sort.cpp b/tests/cpp/test_expr_sort.cpp index 81f63adef36..e9f3a0a7670 100644 --- a/tests/cpp/test_expr_sort.cpp +++ b/tests/cpp/test_expr_sort.cpp @@ -13,6 +13,7 @@ #include #include +#include #include #include @@ -25,6 +26,8 @@ namespace nvfuser { using ExprSortTest = NVFuserTest; using testing::ElementsAre; +using testing::IsTrue; +using testing::Property; using testing::SizeIs; // Indirect normalization pattern with zero-dimensional tensors. Originally @@ -174,7 +177,7 @@ MATCHER_P(UnaryOpTypeIs, unary_op_type, "") { } // namespace -TEST_F(ExprSortTest, SegmentedGroup) { +TEST_F(ExprSortTest, SegmentedGroup_Unary) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); @@ -199,8 +202,8 @@ TEST_F(ExprSortTest, SegmentedGroup) { FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); SegmentedFusion* segmented_fusion = runtime->fusionSegments(); ASSERT_THAT(segmented_fusion->groups(), SizeIs(1)); - SegmentedGroup* group = segmented_fusion->groups().front(); + EXPECT_THAT( group->stablyOrderedExprs(), ElementsAre( @@ -209,4 +212,36 @@ TEST_F(ExprSortTest, SegmentedGroup) { UnaryOpTypeIs(UnaryOpType::Cos))); } +TEST_F(ExprSortTest, SegmentedGroup_Binary_SameOperand) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + TensorView* in = makeSymbolicTensor(1); + TensorView* out = neg(in); + out = add(out, out); + + fusion->addInput(in); + fusion->addOutput(out); + + FusionExecutorCache executor_cache(std::move(fusion)); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA); + at::Tensor in_tensor = at::randn({5}, options); + auto out_tensors = executor_cache.runFusionWithInputs({in_tensor}); + + testValidate( + executor_cache.fusion(), out_tensors, {in_tensor}, __LINE__, __FILE__); + + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + SegmentedFusion* segmented_fusion = runtime->fusionSegments(); + ASSERT_THAT(segmented_fusion->groups(), SizeIs(1)); + SegmentedGroup* group = segmented_fusion->groups().front(); + + EXPECT_THAT( + group->stablyOrderedExprs(), + ElementsAre( + Property(&Expr::isA, IsTrue()), + Property(&Expr::isA, IsTrue()))); +} + } // namespace nvfuser diff --git a/tests/cpp/test_gpu_outer_reduction.cpp b/tests/cpp/test_gpu_outer_reduction.cpp index e4f8c0cc425..7646db171c0 100644 --- a/tests/cpp/test_gpu_outer_reduction.cpp +++ b/tests/cpp/test_gpu_outer_reduction.cpp @@ -2559,7 +2559,7 @@ TEST_F(OuterReductionTest, IterGroupedMultipleReductions) { } // Repro of https://github.com/NVIDIA/Fuser/pull/2766 -TEST_F(NVFuserTest, SmallOuterBlockReductionIssue2766) { +TEST_F(OuterReductionTest, SmallOuterBlockReductionIssue2766) { std::unique_ptr fusion_ptr = std::make_unique(); auto& fusion = *fusion_ptr; FusionGuard fg(&fusion); @@ -2595,4 +2595,48 @@ TEST_F(NVFuserTest, SmallOuterBlockReductionIssue2766) { testValidate(executor_cache.fusion(), outputs, args, __LINE__, __FILE__); } +TEST_F(OuterReductionTest, SimpleThreadLocalSerialReduction) { + auto fusion_ptr = std::make_unique(); + FusionGuard fg(fusion_ptr.get()); + Fusion& fusion = *fusion_ptr; + + std::vector shape{28, 8192, 128}; + + auto T0 = makeContigConcreteTensor(shape, DataType::BFloat16); + fusion.addInput(T0); + auto T1 = castOp(DataType::Float, T0); + auto T2 = sum(T1, {0}); + fusion.addOutput(T2); + + auto fusion_copy = fusion; + + auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0); + auto at_t0 = at::randn(shape, options); + KernelArgumentHolder args = {at_t0}; + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(args); + + // If thread local reduction is used on the tested GPU, the reduction tv + // should be: [..., rS{7}, iV{x}, rUS{1}, rUR{x}] + auto runtime = executor_cache.getMostRecentKernelRuntime(); + for (auto& params : runtime->schedulerHeuristics()->heuristicsList()) { + if (!params->isA()) { + continue; + } + if (!params->as()->cross_block_inner_reduction) { + Fusion* scheduled_fusion = runtime->executors() + .back() + ->as() + ->compiledKernel() + ->kernel(); + auto redu_tv = scheduler_utils::getReductionTvs(scheduled_fusion).at(0); + EXPECT_TRUE(redu_tv->axis(-4)->isReduction()) + << "Expected redu tv is [..., rS{7}, iV{x}, rUS{1}, rUR{x}], got: " + << redu_tv->toString(); + } + } + + testValidate(&fusion_copy, outputs, {at_t0}, __LINE__, __FILE__); +} + } // namespace nvfuser diff --git a/tests/cpp/test_host_ir_stream_lowering.cpp b/tests/cpp/test_host_ir_stream_lowering.cpp new file mode 100644 index 00000000000..b77df002bc6 --- /dev/null +++ b/tests/cpp/test_host_ir_stream_lowering.cpp @@ -0,0 +1,814 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace nvfuser { + +namespace hir { + +using HirLowerStreamTest = NVFuserTest; + +TEST_F(HirLowerStreamTest, InputsAreNotStreamParallelized) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv = makeContigTensor(2); + hic->addInput(tv); + tv->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + preseg_passes::OptimizationPass::runPass(hic.get())); +} + +TEST_F(HirLowerStreamTest, Split) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + hic->addInput(tv0); + hic->addOutput(tv1); + hic->pushBackTopLevelExprs(tv1->definition()); + tv1->split(0, 2); + tv1->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + preseg_passes::OptimizationPass::runPass(hic.get())); +} + +TEST_F(HirLowerStreamTest, Merge) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + hic->addInput(tv0); + hic->addOutput(tv1); + hic->pushBackTopLevelExprs(tv1->definition()); + tv1->merge(0, 1); + tv1->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + preseg_passes::OptimizationPass::runPass(hic.get())); +} + +TEST_F(HirLowerStreamTest, SingleSetOp) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + hic->addInput(tv0); + hic->addOutput(tv1); + hic->pushBackTopLevelExprs(tv1->definition()); + tv0->setMemoryType(MemoryType::Global); + tv1->setMemoryType(MemoryType::Global); + tv1->axis(0)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass::runPass(hic.get()); + + EXPECT_EQ(hic->topLevelExprs().size(), 2); + EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(1)->isA()); + + HostIrEvaluator hie(std::move(hic)); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8}, options); + auto output = hie.runWithInput({{tv0, input}})[0].as(); + + torch::cuda::synchronize(); + EXPECT_TRUE(output.equal(input)) + << "Output: " << output << " Expected: " << input; +} + +TEST_F(HirLowerStreamTest, SingleSetOpNonOutermost) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + hic->addInput(tv0); + hic->addOutput(tv1); + hic->pushBackTopLevelExprs(tv1->definition()); + tv0->setMemoryType(MemoryType::Global); + tv1->setMemoryType(MemoryType::Global); + tv1->axis(1)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass::runPass(hic.get()); + + EXPECT_EQ(hic->topLevelExprs().size(), 2); + EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(1)->isA()); + + HostIrEvaluator hie(std::move(hic)); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8}, options); + auto output = hie.runWithInput({{tv0, input}})[0].as(); + + torch::cuda::synchronize(); + EXPECT_TRUE(output.equal(input)) + << "Output: " << output << " Expected: " << input; +} + +TEST_F(HirLowerStreamTest, SingleBinaryOp) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = makeContigTensor(2); + TensorView* tv2 = add(tv0, tv1); + hic->addInput(tv0); + hic->addInput(tv1); + hic->addOutput(tv2); + hic->pushBackTopLevelExprs(tv2->definition()); + tv0->setMemoryType(MemoryType::Global); + tv1->setMemoryType(MemoryType::Global); + tv2->setMemoryType(MemoryType::Global); + tv2->axis(0)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass::runPass(hic.get()); + + EXPECT_EQ(hic->topLevelExprs().size(), 2); + EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(1)->isA()); + + HostIrEvaluator hie(std::move(hic)); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor tv0_input = at::rand({4, 4}, options); + at::Tensor tv1_input = at::rand({4, 4}, options); + // std::unordered_map inputs = {{tv0, input}}; + auto output = hie.runWithInput({{tv0, tv0_input}, {tv1, tv1_input}})[0] + .as(); + auto expected_output = tv0_input + tv1_input; + EXPECT_TRUE(output.equal(expected_output)) + << "Output: " << output << "Expected: " << expected_output; +} + +TEST_F(HirLowerStreamTest, TwoSetOps) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + TensorView* tv2 = set(tv1); + hic->addInput(tv0); + hic->addOutput(tv2); + hic->pushBackTopLevelExprs(tv1->definition()); + hic->pushBackTopLevelExprs(tv2->definition()); + tv0->setMemoryType(MemoryType::Global); + tv1->setMemoryType(MemoryType::Global); + tv2->setMemoryType(MemoryType::Global); + tv1->axis(0)->parallelize(ParallelType::Stream); + tv2->axis(0)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass::runPass(hic.get()); + + EXPECT_EQ(hic->topLevelExprs().size(), 3); + EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(1)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(2)->isA()); + + HostIrEvaluator hie(std::move(hic)); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8}, options); + auto output = hie.runWithInput({{tv0, input}})[0].as(); + + torch::cuda::synchronize(); + EXPECT_TRUE(output.equal(input)) + << "Output: " << output << " Expected: " << input; +} + +TEST_F(HirLowerStreamTest, ThreeSetOpsWithDisjointsForLoops) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + TensorView* tv2 = set(tv1); + TensorView* tv3 = set(tv2); + hic->addInput(tv0); + hic->addOutput(tv3); + hic->pushBackTopLevelExprs(tv1->definition()); + hic->pushBackTopLevelExprs(tv2->definition()); + hic->pushBackTopLevelExprs(tv3->definition()); + tv0->setMemoryType(MemoryType::Global); + tv1->setMemoryType(MemoryType::Global); + tv2->setMemoryType(MemoryType::Global); + tv3->setMemoryType(MemoryType::Global); + tv1->axis(0)->parallelize(ParallelType::Stream); + tv3->axis(0)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass::runPass(hic.get()); + + EXPECT_EQ(hic->topLevelExprs().size(), 5); + EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(1)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(2)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(3)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(4)->isA()); + + HostIrEvaluator hie(std::move(hic)); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8}, options); + auto output = hie.runWithInput({{tv0, input}})[0].as(); + + torch::cuda::synchronize(); + EXPECT_TRUE(output.equal(input)) + << "Output: " << output << " Expected: " << input; +} + +TEST_F(HirLowerStreamTest, ReductionUnsupported) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = sum(tv0, {0}); + hic->addInput(tv0); + hic->addOutput(tv1); + hic->pushBackTopLevelExprs(tv1->definition()); + tv0->setMemoryType(MemoryType::Global); + tv1->setMemoryType(MemoryType::Global); + tv1->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + preseg_passes::OptimizationPass::runPass(hic.get())); +} + +TEST_F(HirLowerStreamTest, Reduction) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv0 = makeContigTensor(3); + TensorView* tv1 = sum(tv0, {2}); + hic->addInput(tv0); + hic->addOutput(tv1); + hic->pushBackTopLevelExprs(tv1->definition()); + tv0->setMemoryType(MemoryType::Global); + tv1->setMemoryType(MemoryType::Global); + tv1->axis(0)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass::runPass(hic.get()); + + EXPECT_EQ(hic->topLevelExprs().size(), 2); + EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(1)->isA()); + + HostIrEvaluator hie(std::move(hic)); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8, 2}, options); + auto output = hie.runWithInput({{tv0, input}})[0].as(); + + torch::cuda::synchronize(); + auto expected_output = input.sum(2); + EXPECT_TRUE(output.equal(expected_output)) + << "Output: " << output << " Expected: " << expected_output; +} + +TEST_F(HirLowerStreamTest, Matmul_M) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* a = makeContigTensor(2); + TensorView* b = makeContigTensor(2); + TensorView* c = matmul(a, b); + hic->addInput(a); + hic->addInput(b); + hic->addOutput(c); + hic->pushBackTopLevelExprs(c->definition()); + a->setMemoryType(MemoryType::Global); + b->setMemoryType(MemoryType::Global); + c->setMemoryType(MemoryType::Global); + c->axis(0)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass::runPass(hic.get()); + + EXPECT_EQ(hic->topLevelExprs().size(), 2); + EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(1)->isA()); + + HostIrEvaluator hie(std::move(hic)); + + constexpr int64_t M = 8, K = 4, N = 2; + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor a_aten = at::rand({M, K}, options); + at::Tensor b_aten = at::rand({K, N}, options); + auto output = + hie.runWithInput({{a, a_aten}, {b, b_aten}})[0].as(); + + torch::cuda::synchronize(); + auto expected_output = at::matmul(a_aten, b_aten); + EXPECT_TRUE(torch::allclose(output, expected_output, 1e-2, 1e-2)) + << "Output: " << output << " Expected: " << expected_output; +} + +TEST_F(HirLowerStreamTest, BatchedMatmul) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* a = makeContigTensor(3); + TensorView* b = makeContigTensor(2); + TensorView* c = matmul(a, b); + hic->addInput(a); + hic->addInput(b); + hic->addOutput(c); + hic->pushBackTopLevelExprs(c->definition()); + a->setMemoryType(MemoryType::Global); + b->setMemoryType(MemoryType::Global); + c->setMemoryType(MemoryType::Global); + c->axis(0)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass::runPass(hic.get()); + + EXPECT_EQ(hic->topLevelExprs().size(), 2); + EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(1)->isA()); + + HostIrEvaluator hie(std::move(hic)); + + constexpr int64_t B = 16, M = 8, K = 4, N = 2; + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor a_aten = at::rand({B, M, K}, options); + at::Tensor b_aten = at::rand({K, N}, options); + auto output = + hie.runWithInput({{a, a_aten}, {b, b_aten}})[0].as(); + + torch::cuda::synchronize(); + auto expected_output = at::matmul(a_aten, b_aten); + EXPECT_TRUE(torch::allclose(output, expected_output, 1e-2, 1e-2)) + << "Output: " << output << " Expected: " << expected_output; +} + +TEST_F(HirLowerStreamTest, Matmul_N) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* a = makeContigTensor(2); + TensorView* b = makeContigTensor(2); + TensorView* c = matmul(a, b); + hic->addInput(a); + hic->addInput(b); + hic->addOutput(c); + hic->pushBackTopLevelExprs(c->definition()); + a->setMemoryType(MemoryType::Global); + b->setMemoryType(MemoryType::Global); + c->setMemoryType(MemoryType::Global); + c->axis(1)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass::runPass(hic.get()); + + EXPECT_EQ(hic->topLevelExprs().size(), 2); + EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(1)->isA()); + + HostIrEvaluator hie(std::move(hic)); + + constexpr int64_t M = 8, K = 4, N = 2; + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor a_aten = at::rand({M, K}, options); + at::Tensor b_aten = at::rand({K, N}, options); + auto output = + hie.runWithInput({{a, a_aten}, {b, b_aten}})[0].as(); + + torch::cuda::synchronize(); + auto expected_output = at::matmul(a_aten, b_aten); + EXPECT_TRUE(torch::allclose(output, expected_output, 1e-2, 1e-2)) + << "Output: " << output << " Expected: " << expected_output; +} + +TEST_F(HirLowerStreamTest, Matmul_K) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* a = makeContigTensor(2); + TensorView* b = makeContigTensor(2); + TensorView* c = matmul(a, b); + hic->addInput(a); + hic->addInput(b); + hic->addOutput(c); + hic->pushBackTopLevelExprs(c->definition()); + a->setMemoryType(MemoryType::Global); + b->setMemoryType(MemoryType::Global); + c->setMemoryType(MemoryType::Global); + c->axis(-1)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + preseg_passes::OptimizationPass::runPass(hic.get())); +} + +// We don's support PostOnStream because it does not support well pre-allocated +// outputs. There is no strong motivation to support PostOnStream +TEST_F(HirLowerStreamTest, DoNotSupportPostOnStream) { + const std::vector input_sizes = {4, 8, 32}; + const std::vector output_sizes = { + input_sizes.at(1), input_sizes.at(2)}; + + auto get_fusion = [input_sizes]() -> std::unique_ptr { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeConcreteTensor(input_sizes); + auto tv1 = add(tv0, tv0); + auto tv2 = sum(tv1, {0}); + fusion->addInput(tv0); + fusion->addOutput(tv2); + return fusion; + }; + + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + + auto host_unit = IrBuilder::create(get_fusion()); + + IrCloner ir_cloner(hic.get()); + TensorView* input = + ir_cloner.clone(host_unit->fusion_to_execute()->inputs().at(0)) + ->as(); + TensorView* output = + ir_cloner.clone(host_unit->fusion_to_execute()->outputs().at(0)) + ->as(); + + std::vector inputs = {input}; + std::vector outputs = {output}; + auto post_on_stream = + IrBuilder::create(host_unit, inputs, outputs); + + hic->pushBackTopLevelExprs(post_on_stream); + + hic->addInput(input); + hic->addOutput(output); + + output->axis(-1)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + preseg_passes::OptimizationPass::runPass(hic.get())); +} + +} // namespace hir + +using MultiDeviceExecutorLowerStreamTest = NVFuserTest; + +TEST_F(MultiDeviceExecutorLowerStreamTest, InputsAreNotStreamParallelized) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv = makeContigTensor(2); + fusion->addInput(tv); + tv->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + MultiDeviceExecutor(std::move(fusion), Communicator::getInstance())); +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, Split) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + fusion->addInput(tv0); + fusion->addOutput(tv1); + tv1->split(0, 2); + tv1->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + MultiDeviceExecutor(std::move(fusion), Communicator::getInstance())); +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, Merge) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + fusion->addInput(tv0); + fusion->addOutput(tv1); + tv1->merge(0, 1); + tv1->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + MultiDeviceExecutor(std::move(fusion), Communicator::getInstance())); +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, SingleSetOp) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + fusion->addInput(tv0); + fusion->addOutput(tv1); + tv1->axis(0)->parallelize(ParallelType::Stream); + + MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance()); + + hir::HostIrContainer* container = executor.hostIrEvaluator()->container(); + EXPECT_EQ(container->topLevelExprs().size(), 2); + EXPECT_TRUE(container->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(1)->isA()); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8}, options); + auto output = + executor.runWithInput(KernelArgumentHolder({input}))[0].as(); + + torch::cuda::synchronize(); + EXPECT_TRUE(output.equal(input)) + << "Output: " << output << " Expected: " << input; +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, SingleSetOpNonOutermost) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + fusion->addInput(tv0); + fusion->addOutput(tv1); + tv1->axis(1)->parallelize(ParallelType::Stream); + + MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance()); + + hir::HostIrContainer* container = executor.hostIrEvaluator()->container(); + EXPECT_EQ(container->topLevelExprs().size(), 2); + EXPECT_TRUE(container->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(1)->isA()); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8}, options); + auto output = + executor.runWithInput(KernelArgumentHolder({input}))[0].as(); + + torch::cuda::synchronize(); + EXPECT_TRUE(output.equal(input)) + << "Output: " << output << " Expected: " << input; +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, SingleBinaryOp) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = makeContigTensor(2); + TensorView* tv2 = add(tv0, tv1); + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addOutput(tv2); + tv2->axis(0)->parallelize(ParallelType::Stream); + + MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance()); + + hir::HostIrContainer* container = executor.hostIrEvaluator()->container(); + EXPECT_EQ(container->topLevelExprs().size(), 2); + EXPECT_TRUE(container->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(1)->isA()); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + + at::Tensor tv0_input = at::rand({4, 4}, options); + at::Tensor tv1_input = at::rand({4, 4}, options); + auto output = + executor.runWithInput(KernelArgumentHolder({tv0_input, tv1_input}))[0] + .as(); + auto expected_output = tv0_input + tv1_input; + EXPECT_TRUE(output.equal(expected_output)) + << "Output: " << output << "Expected: " << expected_output; +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, TwoSetOps) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + TensorView* tv2 = set(tv1); + fusion->addInput(tv0); + fusion->addOutput(tv2); + tv1->axis(0)->parallelize(ParallelType::Stream); + tv2->axis(0)->parallelize(ParallelType::Stream); + + MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance()); + + hir::HostIrContainer* container = executor.hostIrEvaluator()->container(); + EXPECT_EQ(container->topLevelExprs().size(), 3); + EXPECT_TRUE(container->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(1)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(2)->isA()); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8}, options); + auto output = + executor.runWithInput(KernelArgumentHolder({input}))[0].as(); + + torch::cuda::synchronize(); + EXPECT_TRUE(output.equal(input)) + << "Output: " << output << " Expected: " << input; +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, ThreeSetOpsWithDisjointsForLoops) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + TensorView* tv2 = set(tv1); + TensorView* tv3 = set(tv2); + fusion->addInput(tv0); + fusion->addOutput(tv3); + tv1->axis(0)->parallelize(ParallelType::Stream); + tv3->axis(0)->parallelize(ParallelType::Stream); + + MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance()); + + hir::HostIrContainer* container = executor.hostIrEvaluator()->container(); + EXPECT_EQ(container->topLevelExprs().size(), 5); + EXPECT_TRUE(container->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(1)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(2)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(3)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(4)->isA()); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8}, options); + auto output = + executor.runWithInput(KernelArgumentHolder({input}))[0].as(); + + torch::cuda::synchronize(); + EXPECT_TRUE(output.equal(input)) + << "Output: " << output << " Expected: " << input; +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, ReductionUnsupported) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = sum(tv0, {0}); + fusion->addInput(tv0); + fusion->addOutput(tv1); + tv1->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + MultiDeviceExecutor(std::move(fusion), Communicator::getInstance())); +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, Reduction) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(3); + TensorView* tv1 = sum(tv0, {2}); + fusion->addInput(tv0); + fusion->addOutput(tv1); + tv1->axis(0)->parallelize(ParallelType::Stream); + + MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance()); + + hir::HostIrContainer* container = executor.hostIrEvaluator()->container(); + EXPECT_EQ(container->topLevelExprs().size(), 2); + EXPECT_TRUE(container->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(1)->isA()); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8, 2}, options); + auto output = + executor.runWithInput(KernelArgumentHolder({input}))[0].as(); + + torch::cuda::synchronize(); + auto expected_output = input.sum(2); + EXPECT_TRUE(output.equal(expected_output)) + << "Output: " << output << " Expected: " << expected_output; +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, Matmul_M) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* a = makeContigTensor(2); + TensorView* b = makeContigTensor(2); + TensorView* c = matmul(a, b); + fusion->addInput(a); + fusion->addInput(b); + fusion->addOutput(c); + c->axis(0)->parallelize(ParallelType::Stream); + + MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance()); + + hir::HostIrContainer* container = executor.hostIrEvaluator()->container(); + EXPECT_EQ(container->topLevelExprs().size(), 2); + EXPECT_TRUE(container->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(1)->isA()); + + constexpr int64_t M = 8, K = 4, N = 2; + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor a_aten = at::rand({M, K}, options); + at::Tensor b_aten = at::rand({K, N}, options); + auto output = executor.runWithInput(KernelArgumentHolder({a_aten, b_aten}))[0] + .as(); + + torch::cuda::synchronize(); + auto expected_output = at::matmul(a_aten, b_aten); + EXPECT_TRUE(torch::allclose(output, expected_output, 1e-2, 1e-2)) + << "Output: " << output << " Expected: " << expected_output; +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, BatchedMatmul) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* a = makeContigTensor(3); + TensorView* b = makeContigTensor(2); + TensorView* c = matmul(a, b); + fusion->addInput(a); + fusion->addInput(b); + fusion->addOutput(c); + c->axis(0)->parallelize(ParallelType::Stream); + + MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance()); + + hir::HostIrContainer* container = executor.hostIrEvaluator()->container(); + EXPECT_EQ(container->topLevelExprs().size(), 2); + EXPECT_TRUE(container->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(1)->isA()); + + constexpr int64_t B = 16, M = 8, K = 4, N = 2; + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor a_aten = at::rand({B, M, K}, options); + at::Tensor b_aten = at::rand({K, N}, options); + auto output = executor.runWithInput(KernelArgumentHolder({a_aten, b_aten}))[0] + .as(); + + torch::cuda::synchronize(); + auto expected_output = at::matmul(a_aten, b_aten); + EXPECT_TRUE(torch::allclose(output, expected_output, 1e-2, 1e-2)) + << "Output: " << output << " Expected: " << expected_output; +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, Matmul_N) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* a = makeContigTensor(2); + TensorView* b = makeContigTensor(2); + TensorView* c = matmul(a, b); + fusion->addInput(a); + fusion->addInput(b); + fusion->addOutput(c); + c->axis(1)->parallelize(ParallelType::Stream); + + MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance()); + + hir::HostIrContainer* container = executor.hostIrEvaluator()->container(); + EXPECT_EQ(container->topLevelExprs().size(), 2); + EXPECT_TRUE(container->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(1)->isA()); + + constexpr int64_t M = 8, K = 4, N = 2; + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor a_aten = at::rand({M, K}, options); + at::Tensor b_aten = at::rand({K, N}, options); + auto output = executor.runWithInput(KernelArgumentHolder({a_aten, b_aten}))[0] + .as(); + + torch::cuda::synchronize(); + auto expected_output = at::matmul(a_aten, b_aten); + EXPECT_TRUE(torch::allclose(output, expected_output, 1e-2, 1e-2)) + << "Output: " << output << " Expected: " << expected_output; +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, Matmul_K) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* a = makeContigTensor(2); + TensorView* b = makeContigTensor(2); + TensorView* c = matmul(a, b); + fusion->addInput(a); + fusion->addInput(b); + fusion->addOutput(c); + c->axis(-1)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + MultiDeviceExecutor(std::move(fusion), Communicator::getInstance())); +} + +// We only support Stream parallel type on ops that support pre-allocated +// output, which means they need a special handle in HostIrEvaluator and they +// need to be lowered as a Host Ir Op in the TopLevelExpression, no a +// PostOnStream(HostUnit(.)) See HostIrLower::isLoweredAsStandaloneHostOp and +// the test HirLowerStreamTest.DoNotSupportPostOnStream +TEST_F(MultiDeviceExecutorLowerStreamTest, DoNotSupportPostOnStream) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = + abs(tv0); // arbitrary example of an unsupported op. There is no deep + // reason why we not support it -- if needed we could widen the + // support. But I wanna make sure that an unsupported op do not + // silently fails + fusion->addInput(tv0); + fusion->addOutput(tv1); + tv1->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + MultiDeviceExecutor(std::move(fusion), Communicator::getInstance())); +} + +} // namespace nvfuser diff --git a/tests/cpp/test_host_irs.cpp b/tests/cpp/test_host_irs.cpp index 633ebc83504..06c029fb6f6 100644 --- a/tests/cpp/test_host_irs.cpp +++ b/tests/cpp/test_host_irs.cpp @@ -1307,6 +1307,44 @@ TEST_F(HirAlias, ThrowOnInputAlias) { EXPECT_ANY_THROW(HostIrEvaluator hie(std::move(hic))); } +using HirAliasSelectHostIrTest = NVFuserTest; + +TEST_F(HirAliasSelectHostIrTest, SelectingTensor) { + constexpr int64_t ndims = 2; + constexpr int64_t dim = 1; + constexpr int64_t index = 3; + const std::vector input_sizes = {32, 32}; + + ASSERT_LT(dim, ndims); + ASSERT_EQ(input_sizes.size(), ndims); + ASSERT_LT(index, input_sizes.at(dim)); + + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + + TensorView* in = makeContigTensor(ndims); + TensorView* out = makeContigTensor(ndims - 1); + auto* index_val = IrBuilder::create(index, DataType::Index); + auto* select_op = IrBuilder::create(in, out, dim, index_val); + + hic->addInput(in); + hic->addOutput(out); + hic->pushBackTopLevelExprs(select_op); + + HostIrEvaluator hie(std::move(hic)); + + auto options = at::TensorOptions().device(at::kCUDA, 0).dtype(torch::kFloat); + auto in_aten = at::randn(input_sizes, options); + std::unordered_map concrete_input_buffers = { + {in, in_aten}}; + + auto out_aten = hie.runWithInput(concrete_input_buffers)[0].as(); + + // validate + auto ref_out = in_aten.select(dim, index); + EXPECT_TRUE(ref_out.equal(out_aten)); +} + using HirSetTest = NVFuserTest; TEST_F(HirSetTest, HostIr) { diff --git a/tests/cpp/test_index_put.cpp b/tests/cpp/test_index_put.cpp new file mode 100644 index 00000000000..024ef910398 --- /dev/null +++ b/tests/cpp/test_index_put.cpp @@ -0,0 +1,120 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include + +#include +#include +#include +#include + +namespace nvfuser { + +struct SizeParams { + int64_t vocab_size; + int64_t hidden_size; + int64_t seq_size; +}; + +std::vector generateSizeOneParams() { + int64_t vocab_size = 1024; + int64_t hidden_size = 3584; + int64_t seq_size = 3000; + std::vector params; + for (bool size_one_vocab : {true, false}) { + for (bool size_one_hidden : {true, false}) { + for (bool size_one_seq : {true, false}) { + int64_t vocab = size_one_vocab ? 1 : vocab_size; + int64_t hidden = size_one_hidden ? 1 : hidden_size; + int64_t seq = size_one_seq ? 1 : seq_size; + params.push_back({vocab, hidden, seq}); + } + } + } + return params; +} + +class IndexPut : public NVFuserFixtureParamTest { + protected: + void SetUp() override { + EnableOptionsGuard::getCurOptions().set(EnableOption::IdModel, {"all"}); + NVFuserTest::SetUp(); + } +}; + +INSTANTIATE_TEST_SUITE_P( + , + IndexPut, + ::testing::ValuesIn(generateSizeOneParams())); + +// Note: The semantics doesn't support broadcast on operands, adding `size 1` +// check just to ensure the ID mapping is done correctly. +TEST_P(IndexPut, AccumulateOpWithBroadcastIDs) { + auto fusion_ptr = std::make_unique(); + Fusion& fusion = *fusion_ptr.get(); + FusionGuard fg(&fusion); + + auto [vocab, hidden, seq] = GetParam(); + + std::vector shape1({seq, hidden}); + std::vector shape2({seq, 1}); + + auto tv_value = makeSymbolicTensor(shape1); + fusion.addInput(tv_value); + auto tv_index = makeSymbolicTensor(shape2, DataType::Int); + fusion.addInput(tv_index); + auto s_vocab = IrBuilder::create(vocab, DataType::Index); + std::vector buffer_size = { + s_vocab, tv_value->axis(-1)->extent()}; + auto buf = zeros(buffer_size, DataType::Float, true); + // TODO: this should be an inplace. handle it when we have codegen support + auto out = indexPutAccumulate(buf, tv_index, tv_value); + fusion.addOutput(out); + + // check PairwiseLogicalDomainMap check if tv0 and tv1 map pairwise on + // position according to `expect_to_map` + auto map_logical = [](const std::vector& expect_to_map, + TensorView* tv0, + TensorView* tv1) { + std::unordered_map pairwise_map = + PairwiseLogicalDomainMap(tv0, tv1).mapProducerToConsumer(); + for (auto index : arange(expect_to_map.size())) { + IterDomain* id0 = tv0->getLogicalDomain().at(index); + IterDomain* id1 = tv1->getLogicalDomain().at(index); + EXPECT_EQ( + pairwise_map.find(id0) != pairwise_map.end() && + pairwise_map[id0] == id1, + expect_to_map[index]); + } + }; + + // see [ Note -- IndexPutAccumulateOp semantics ] + // args: + // buf [ ID_indexed_g0, ID_g0 ] + // tv_index [ ID_indexing_g1, ID_broadcast ] + // tv_value [ ID_indexing_g1, ID_g0 ] + // output: + // out [ ID_indexed_g0, ID_g0 ] + map_logical({true, true}, buf, out); + // depends on the size of ID_g0, it would map to ID_broadcast when hidden is + // size-1 dimension + map_logical({false, hidden == 1}, tv_index, out); + map_logical({false, true}, tv_value, out); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto options_i = at::TensorOptions().dtype(at::kLong).device(at::kCUDA, 0); + auto t_value = at::randn(shape1, options); + auto t_index = at::randint(0, vocab, shape2, options_i); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({t_value, t_index}); + + testValidate(&fusion, outputs, {t_value, t_index}, __LINE__, __FILE__); +} + +} // namespace nvfuser diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index 14a3ef84d62..8a7e6c3782b 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -3656,6 +3656,32 @@ class HopperMatmulTest : public HopperBase { } }; +// 2 math group, non-persistent, non-warp specialized, no CGA +// TODO: This could be in HopperMatmulTest::SetUp() instead +MatmulParams defaultHopperParams() { + MatMulTileOptions gemm_tile; + gemm_tile.cta_tile = GemmTile(128, 256, 64); + gemm_tile.warp_tile = GemmTile(64, 256, 64); + MatmulParams mparams; + mparams.supported_vec_size = {8, 8, 8}; + mparams.mma_macro = MmaMacro::Hopper_64_256_16; + mparams.tile_sizes = gemm_tile; + mparams.circular_buffering_strategy = + MatmulParams::CircularBufferingStrategy::Pipelined; + mparams.tiling_strategy = MatmulParams::TilingStrategy::OneTilePerCTA; + mparams.cta_order = MatmulParams::TileRasterizationOrder::ColumnMajor; + mparams.async_gmem_load_operands = true; + mparams.circular_buffer_options.circular_buffer_smem_write = true; + mparams.circular_buffer_options.circular_buffer_smem_read = false; + mparams.circular_buffer_options.smem_circular_buffer_stage = 4; + mparams.circular_buffer_options.smem_circular_buffer_prefetch_gap = 1; + mparams.splitk_factor = 1; + mparams.use_smem_epilogue = true; + mparams.cluster_dims = {1, 1, 1}; + mparams.promote_prologue_smem_reuse = true; + return mparams; +} + TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) { Fusion fusion; FusionGuard fg(&fusion); @@ -5360,4 +5386,60 @@ TEST_F(HopperMatmulTest, HSH_NT_SingleMathGroupSyncCheck) { cg_outputs[0].as(), out_ref, 1e-6 * K, 1e-6 * K)); } +// See https://github.com/NVIDIA/Fuser/issues/4159 +TEST_F(HopperMatmulTest, HSS_NT_SplitKTMAStore) { + Fusion fusion; + FusionGuard fg(&fusion); + + constexpr int64_t M = 2048, N = 2048, K = 8192; + const auto dtype = DataType::Half; + + auto tv0 = makeContigConcreteTensor({-1, -1, 1}, dtype); // K, M + auto tv1 = makeContigConcreteTensor({-1, 1, -1}, dtype); // K, N + fusion.addInput(tv0); + fusion.addInput(tv1); + + auto tv2 = fusedMultiplySum(tv0, tv1, {0}); + + // Reorder the accumulator as [M, N, K] + // [K, M, N] -> [M, N, K] + tv2->reorder({{-3, -1}}); + tv2->commitLeafToLogical(); + + fusion.addOutput(tv2); + + auto options = at::TensorOptions().dtype(at::kHalf).device(at::kCUDA); + auto t0 = at::randn({K, M, 1}, options); + auto t1 = at::randn({K, 1, N}, options); + auto out_ref = + at::matmul(t0.squeeze().t().to(at::kFloat), t1.squeeze().to(at::kFloat)); + + MatmulParams mparams = defaultHopperParams(); + mparams.use_smem_epilogue = true; + mparams.splitk_factor = 2; + + SchedulerEntry::makeSchedulerInstance(SchedulerType::Matmul) + ->schedule(&fusion, &mparams); + + KernelExecutor ke; + ke.compile(&fusion, {t0, t1}); + // TODO: Either enable stmatrix for 32-bit outputs or fix current 2-way bank + // conflict by scheduling the vectorized store properly + auto bank_conflicts = getBankConflictInfo(ke.compiledKernel()->kernel()); + EXPECT_EQ(bank_conflicts.size(), 1); + for (const auto& [expr, conflict_ways] : bank_conflicts) { + int64_t input_ways, output_ways; + std::tie(input_ways, output_ways) = conflict_ways; + EXPECT_EQ(input_ways, 0); + EXPECT_EQ(output_ways, 2); + } + auto cg_outputs = ke.run({t0, t1}); + ASSERT_FALSE(PredicatedChecker::isCpAsyncMmaPredicatedByIfThenElse( + ke.compiledKernel()->kernel())); + + // Relax tolerance for larger sum due to large K + NVF_CHECK(at::allclose( + cg_outputs[0].as(), out_ref, 1e-6 * K, 1e-6 * K)); +} + } // namespace nvfuser diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp index 88286d6e4c0..db53f7f114d 100644 --- a/tests/cpp/test_multidevice_host_ir.cpp +++ b/tests/cpp/test_multidevice_host_ir.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -362,6 +363,10 @@ TEST_F(P2PCommHostIrTest, CoalescedRingPairwiseExchange) { using OverlapDistributedMatmulTest = MultiDeviceTest; TEST_F(OverlapDistributedMatmulTest, AG_matmul) { + // Disable StreamParallelType pass temporarily as proper stream lowering gets + // implemented + preseg_passes::OptimizationPassGuard guard(false); + constexpr int64_t M = 32768; constexpr int64_t K = 32768; constexpr int64_t N = 1024; @@ -417,6 +422,9 @@ TEST_F(OverlapDistributedMatmulTest, AG_matmul) { } TEST_F(OverlapDistributedMatmulTest, AG_linear) { + // Disable StreamParallelType pass tempor + preseg_passes::OptimizationPassGuard guard(false); + constexpr int64_t M = 32768; constexpr int64_t K = 32768; constexpr int64_t N = 1024; diff --git a/tests/cpp/test_multidevice_ipc.cpp b/tests/cpp/test_multidevice_ipc.cpp index 30daf6db145..ba574c0f676 100644 --- a/tests/cpp/test_multidevice_ipc.cpp +++ b/tests/cpp/test_multidevice_ipc.cpp @@ -34,7 +34,7 @@ TEST_F(IpcTest, IpcMemHandle) { if (communicator_->size() == 1) { GTEST_SKIP() << "Skipping test for single device"; } -#ifdef NVFUSER_DISTRIBUTED + // Allocate and setup GPU buffers constexpr size_t kBufferSize = sizeof(int64_t); const int64_t num_devices = communicator_->size(); @@ -75,16 +75,13 @@ TEST_F(IpcTest, IpcMemHandle) { // Clean up NVFUSER_CUDA_RT_SAFE_CALL(cudaIpcCloseMemHandle(peer_d_ptr)); NVFUSER_CUDA_RT_SAFE_CALL(cudaFree(d_ptr)); -#else // NVFUSER_DISTRIBUTED - GTEST_SKIP() << "NVFUSER_DISTRIBUTED is not defined"; -#endif // NVFUSER_DISTRIBUTED } TEST_F(IpcTest, IpcMemHandlePtrArithmeticAtReceiver) { if (communicator_->size() == 1) { GTEST_SKIP() << "Skipping test for single device"; } -#ifdef NVFUSER_DISTRIBUTED + // TL;DR: We can do pointer arithmetic on the importer side. IOW, the pointer // can be used as a regular pointer on the importer side. @@ -131,16 +128,13 @@ TEST_F(IpcTest, IpcMemHandlePtrArithmeticAtReceiver) { // Clean up NVFUSER_CUDA_RT_SAFE_CALL(cudaIpcCloseMemHandle(peer_d_ptr)); NVFUSER_CUDA_RT_SAFE_CALL(cudaFree(d_ptr)); -#else // NVFUSER_DISTRIBUTED - GTEST_SKIP() << "NVFUSER_DISTRIBUTED is not defined"; -#endif // NVFUSER_DISTRIBUTED } TEST_F(IpcTest, IpcMemHandlePtrArithmeticAtSender) { if (communicator_->size() == 1) { GTEST_SKIP() << "Skipping test for single device"; } -#ifdef NVFUSER_DISTRIBUTED + // TL;DR: We CANNOT do pointer arithmetic on the exporter side! The IPC handle // points to the beginning of the allocated buffer. @@ -189,9 +183,6 @@ TEST_F(IpcTest, IpcMemHandlePtrArithmeticAtSender) { // Clean up NVFUSER_CUDA_RT_SAFE_CALL(cudaIpcCloseMemHandle(peer_d_ptr)); NVFUSER_CUDA_RT_SAFE_CALL(cudaFree(d_ptr)); -#else // NVFUSER_DISTRIBUTED - GTEST_SKIP() << "NVFUSER_DISTRIBUTED is not defined"; -#endif // NVFUSER_DISTRIBUTED } // cuStreamWriteValue32 and cuStreamWaitValue32 are CUDA driver API used in the diff --git a/tests/cpp/test_multidevice_matmul.cpp b/tests/cpp/test_multidevice_matmul.cpp index b8aba89aa86..ee24479a0af 100644 --- a/tests/cpp/test_multidevice_matmul.cpp +++ b/tests/cpp/test_multidevice_matmul.cpp @@ -238,7 +238,7 @@ TEST_F(DistributedMatmulTest, Matmul_LayoutTN_Allgather) { executor_cache.getMostRecentKernelRuntime(); EXPECT_THAT( kernel_runtime->fusionSegments()->groups(), - Contains(HeuristicIs(SchedulerType::ExprEval)).Times(2)); + Contains(HeuristicIs(SchedulerType::ExprEval)).Times(3)); } TEST_F(DistributedMatmulTest, Matmul_LayoutNT_AllReduce) { diff --git a/tests/cpp/test_multidevice_sharding.cpp b/tests/cpp/test_multidevice_sharding.cpp index 2309dc4cd36..5a26d5d6622 100644 --- a/tests/cpp/test_multidevice_sharding.cpp +++ b/tests/cpp/test_multidevice_sharding.cpp @@ -894,159 +894,4 @@ TEST_F(MultiDeviceTest, LoopShardedMergeReshapeIds) { __FILE__); } -namespace { -// This is a simplified version of what we will eventually do in the -// pre-segmentation pass -void propagateShardings(Fusion* fusion, int64_t num_devices) { - for (Expr* expr : fusion->exprs()) { - if (expr->isA()) { - NVF_THROW("SliceOp is not currently supported"); - } - - if (expr->isA()) { - // TransformPropagator cannot be directly used. - // It raises an error for conflicting transformations from root domain to - // logical domain. Instead, we manually find the reshaped iterdomain and - // outer split DID. This might have to be extended further in the - // presegmentation pass. - // Note: For simplicity, this assumes that the sharding is on reshaped - // IDs. It is possible that the non-reshaped IDs are sharded, in which - // case we can use the TransformPropagator. - TensorView* reshaped_tv = expr->as()->out(); - auto transform_exprs = StmtSort::getExprsBetween( - {reshaped_tv->getMaybeRootDomain().begin(), - reshaped_tv->getMaybeRootDomain().end()}, - {reshaped_tv->getLogicalDomain().begin(), - reshaped_tv->getLogicalDomain().end()}); - NVF_CHECK(transform_exprs.size() == 1); - auto transform = transform_exprs[0]; - NVF_CHECK(transform->isA() || transform->isA()); - - // Get the reshaped ID (outer ID for split reshape). - // This is the ID that will be parallelized. - IterDomain* reshaped_id = transform->isA() - ? transform->as()->outer() - : transform->as()->out(); - - auto reshaped_it = std::find( - reshaped_tv->getLoopDomain().begin(), - reshaped_tv->getLoopDomain().end(), - reshaped_id); - int64_t reshaped_axis = - std::distance(reshaped_tv->getLoopDomain().begin(), reshaped_it); - - // Apply sharding to the reshaped tensor - reshaped_tv->split(reshaped_axis, num_devices, false); - reshaped_tv->axis(reshaped_axis)->parallelize(ParallelType::DIDx); - reorderDIDToFront(reshaped_tv); - continue; - } - - // For other ops, propagate sharding from input to outputs - auto input_tv = expr->input(0)->as(); - std::vector output_tvs; - for (auto output : expr->outputs()) { - output_tvs.push_back(output->as()); - } - - TransformPropagator propagator(input_tv); - - // Note: We will finally propagate from each input iteratively. - SetSelector selector( - std::unordered_set(output_tvs.begin(), output_tvs.end())); - MaxLogicalDomainInfoSpanningTree(input_tv, &selector).traverse(&propagator); - shardAllLike(input_tv, output_tvs); - } -} - -} // namespace - -TEST_F(MultiDeviceTest, TransformerFwd) { - auto fusion = std::make_unique(); - FusionGuard fg(fusion.get()); - - const int d = communicator_->size(); - const int64_t b = 2, s = 3, h = 8, e = 16; - auto mesh = DeviceMesh::createForNumDevices(d); - - std::vector in_shape = {b, s, d * h * e}; - std::vector out_shape = {b, s, d * h, e}; - - // The transformer block produces hq/hk/hv after slicing the MHA linear - // output. - TensorView* hq = makeConcreteTensor(in_shape, DataType::Half); - TensorView* hk = makeConcreteTensor(in_shape, DataType::Half); - TensorView* hv = makeConcreteTensor(in_shape, DataType::Half); - - TensorView* q = reshape(hq, in_shape, out_shape); - TensorView* q_permuted = permute(q, {0, 2, 1, 3}); - TensorView* k = reshape(hk, in_shape, out_shape); - TensorView* k_permuted = permute(k, {0, 2, 1, 3}); - TensorView* v = reshape(hv, in_shape, out_shape); - TensorView* v_permuted = permute(v, {0, 2, 1, 3}); - - SdpfaFwdResult sdpa_out = sdpfa_fwd( - q_permuted, - k_permuted, - v_permuted, - /*dropout_p=*/IrBuilder::create(0.0), - /*is_causal=*/IrBuilder::create(false), - /*scale=*/nullptr); - - TensorView* attn = sdpa_out.output; - TensorView* attn_permute = permute(attn, {0, 2, 1, 3}); - TensorView* out = reshape(attn_permute, out_shape, in_shape); - - fusion->addInput(hq); - fusion->addInput(hk); - fusion->addInput(hv); - fusion->addOutput(out); - - // Shard input tensors - for (auto* tv : {hq, hk, hv}) { - tv->setDeviceMesh(mesh); - tv->split(-1, d, /*inner_split=*/false); - tv->axis(-2)->parallelize(ParallelType::DIDx); - reorderDIDToFront(tv); - } - propagateShardings(fusion.get(), d); - - for (auto tv : fusion->allTvs()) { - tv->setAllocationDomain(tv->getLoopDomain(), true); - } - - FusionExecutorCache executor_cache(std::move(fusion)); - at::Tensor hq_tensor = at::randn({in_shape}, tensor_options.dtype(at::kHalf)); - at::Tensor hk_tensor = at::randn({in_shape}, tensor_options.dtype(at::kHalf)); - at::Tensor hv_tensor = at::randn({in_shape}, tensor_options.dtype(at::kHalf)); - - at::Tensor sharded_hq = shardTensor(hq_tensor, -1, mesh); - at::Tensor sharded_hk = shardTensor(hk_tensor, -1, mesh); - at::Tensor sharded_hv = shardTensor(hv_tensor, -1, mesh); - - auto nvf_out = - executor_cache - .runFusionWithInputs({sharded_hq, sharded_hk, sharded_hv})[0] - .as(); - - double scale = 1.0 / std::sqrt(e); - auto reference_out = at::_scaled_dot_product_flash_attention( - hq_tensor.view(out_shape).transpose(1, 2), - hk_tensor.view(out_shape).transpose(1, 2), - hv_tensor.view(out_shape).transpose(1, 2), - /*dropout_p=*/0.0, - /*is_causal=*/false, - /*return_debug_mask=*/false, - scale); - at::Tensor ref_attn = shardTensor( - std::get<0>(reference_out).transpose(1, 2).view(in_shape), -1, mesh); - - testValidate( - executor_cache.fusion(), - {nvf_out}, - {sharded_hq, sharded_hk, sharded_hv}, - {ref_attn}, - __LINE__, - __FILE__); -} } // namespace nvfuser diff --git a/tests/cpp/test_multidevice_transformer.cpp b/tests/cpp/test_multidevice_transformer.cpp index 55f23bdc5a1..025726698be 100644 --- a/tests/cpp/test_multidevice_transformer.cpp +++ b/tests/cpp/test_multidevice_transformer.cpp @@ -1016,6 +1016,144 @@ TEST_P(DistributedTransformerTest, Backward) { 0.02}); } +namespace { +at::Tensor reference_loop_split_mlp( + at::Tensor inp, + at::Tensor w0, + at::Tensor w1) { + auto linear0 = at::linear(inp, w0); + auto gelu = at::gelu(linear0, "tanh"); + auto linear1 = at::linear(gelu, w1); + return linear1; +} + +at::Tensor reference_loop_split_mha(at::Tensor inp) { + auto qkv = inp.transpose(1, 2).split(E / H, -1); + double scale = 1.0 / std::sqrt(E / H); + auto sdpa_out = at::_scaled_dot_product_flash_attention( + qkv[0], + qkv[1], + qkv[2], + /*dropout_p=*/kDropoutProb, + /*is_causal=*/true, + /*return_debug_mask=*/false, + scale); + auto attn = std::get<0>(sdpa_out); + return attn.transpose(1, 2); +} +} // namespace + +// TODO: Allow testing for float16 and bfloat16 for loop split mlp and mha +// This currently fails because privatizeUpcast clones cast operations, +// which fails segmentation since the transforms are not replicated. +TEST_F(DistributedTransformerTest, LoopSplitMLP) { + if ((4 * E) % D != 0) { + GTEST_SKIP() << "Requires number of devices=" << D + << " evenly divide 4*E=" << 4 * E; + } + auto dtype = DataType::Float; + at::ScalarType at_dtype = data_type_to_aten(dtype); + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const int d = communicator_->size(); + auto mesh = DeviceMesh::createForNumDevices(d); + + TensorView* inp = makeContigConcreteTensor({B, S, E}, dtype); + TensorView* w0 = makeContigConcreteTensor({4 * E, E}, dtype); + TensorView* w1 = makeContigConcreteTensor({E, 4 * E}, dtype); + + TensorView* linear0 = linear(inp, w0); + TensorView* linear0_float = castOp(DataType::Float, linear0); + TensorView* gelu = tanh_gelu(linear0_float); + TensorView* gelu_dtype = castOp(dtype, gelu); + TensorView* linear1 = linear(gelu_dtype, w1); + + std::vector fusion_inputs{inp, w0, w1}; + for (auto tv : fusion_inputs) { + fusion->addInput(tv); + tv->setDeviceMesh(mesh); + } + fusion->addOutput(linear1); + + w0->outer_split(0, d); + w0->axis(0)->parallelize(ParallelType::DIDx); + w1->outer_split(1, d); + w1->axis(1)->parallelize(ParallelType::DIDx); + + FusionExecutorCache executor_cache(std::move(fusion)); + at::Tensor inp_tensor = at::randn({B, S, E}, tensor_options.dtype(at_dtype)); + at::Tensor w0_tensor = at::randn({4 * E, E}, tensor_options.dtype(at_dtype)); + at::Tensor w1_tensor = at::randn({E, 4 * E}, tensor_options.dtype(at_dtype)); + + at::Tensor w0_sharded = shardTensor(w0_tensor, 0, mesh); + at::Tensor w1_sharded = shardTensor(w1_tensor, 1, mesh); + + KernelArgumentHolder args = {inp_tensor, w0_sharded, w1_sharded}; + auto outputs = executor_cache.runFusionWithInputs(args); + at::Tensor nvf_out = outputs[0].as(); + + at::Tensor ref_out = + reference_loop_split_mlp(inp_tensor, w0_tensor, w1_tensor); + validate({ref_out}, {nvf_out}, {0.02}); +} + +TEST_F(DistributedTransformerTest, LoopSplitMHAFwd) { + if (H % D != 0) { + GTEST_SKIP() << "Requires number of devices=" << D + << " evenly divide H=" << H; + } + + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto dtype = DataType::Half; + at::ScalarType at_dtype = data_type_to_aten(dtype); + + const int d = communicator_->size(); + + auto mesh = DeviceMesh::createForNumDevices(d); + + TensorView* qkv = makeContigConcreteTensor({B, S, H, 3 * E / H}, dtype); + TensorView* q = slice(qkv, {0, 0, 0, 0}, {B, S, H, E / H}); + TensorView* k = slice(qkv, {0, 0, 0, E / H}, {B, S, H, 2 * E / H}); + TensorView* v = slice(qkv, {0, 0, 0, 2 * E / H}, {B, S, H, 3 * E / H}); + + TensorView* q_permuted = permute(q, {0, 2, 1, 3}); + TensorView* k_permuted = permute(k, {0, 2, 1, 3}); + TensorView* v_permuted = permute(v, {0, 2, 1, 3}); + + SdpfaFwdResult sdpa_out = sdpfa_fwd( + q_permuted, + k_permuted, + v_permuted, + /*dropout_p=*/IrBuilder::create(kDropoutProb), + /*is_causal=*/IrBuilder::create(true), + /*scale=*/nullptr); + + TensorView* attn = sdpa_out.output; + TensorView* attn_permute = permute(attn, {0, 2, 1, 3}); + + fusion->addInput(qkv); + fusion->addOutput(attn_permute); + + qkv->setDeviceMesh(mesh); + qkv->outer_split(2, d); + qkv->axis(2)->parallelize(ParallelType::DIDx); + + FusionExecutorCache executor_cache(std::move(fusion)); + at::Tensor unsharded_inp_tensor = + at::randn({B, S, H, 3 * E / H}, tensor_options.dtype(at_dtype)); + at::Tensor inp_tensor = shardTensor(unsharded_inp_tensor, 2, mesh); + + KernelArgumentHolder args = {inp_tensor}; + auto outputs = executor_cache.runFusionWithInputs(args); + at::Tensor nvf_out = outputs[0].as(); + at::Tensor ref_out = reference_loop_split_mha(inp_tensor); + validate({ref_out}, {nvf_out}, {0.02}); +} + INSTANTIATE_TEST_SUITE_P( , DistributedTransformerTest, diff --git a/tests/cpp/test_reduction_pointwise.cpp b/tests/cpp/test_reduction_pointwise.cpp index 98be573ac83..c0ae0c0a65f 100644 --- a/tests/cpp/test_reduction_pointwise.cpp +++ b/tests/cpp/test_reduction_pointwise.cpp @@ -158,4 +158,106 @@ TEST_F(NVFuserTest, InnerReductionUnrollVectorization) { testValidate(&fusion_copy, cg_outputs, {t0}, __LINE__, __FILE__); } +// https://github.com/NVIDIA/Fuser/issues/3811 +TEST_F(NVFuserTest, ReductionSchedulerWithAdditionalID) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + // tv0 [ b0, i1 ] + auto tv0 = makeContigConcreteTensor({1, -1}); + fusion.addInput(tv0); + // tv1 [ i2, i1 ] + // current scheduler picks tv0 as the reference TV, transformations are + // propagated to other TVs. + auto tv1 = makeContigTensor(2); + fusion.addInput(tv1); + + auto tv2 = sum(tv0, {0, 1}); + fusion.addOutput(tv2); + auto tv3 = add(tv0, tv1); + fusion.addOutput(tv3); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({1, 100}, options); + auto t1 = at::randn({5, 100}, options); + std::vector inputs({t0, t1}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + + // checking segmentation + auto optimized_fusion = executor_cache.getMostRecentKernelRuntime(); + NVF_CHECK(optimized_fusion->isSegmented(), "segmentation didn't happen!"); +} + +// https://github.com/NVIDIA/Fuser/issues/3811 +TEST_F(NVFuserTest, ReductionSchedulerWithAdditionalIDInnerNormalization) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeContigConcreteTensor({-1, -1, 1}); + fusion.addInput(tv0); + auto tv1 = makeContigTensor(3); + fusion.addInput(tv1); + + auto tv2 = sum(tv0, {1, 2}, /*keep_dim=*/true); + auto tv3 = add(tv0, tv2); + fusion.addOutput(tv3); + auto tv4 = add(tv0, tv1); + fusion.addOutput(tv4); + + fusion.printMath(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({100, 20, 1}, options); + auto t1 = at::randn({100, 20, 128}, options); + std::vector inputs({t0, t1}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + + // checking segmentation + auto optimized_fusion = executor_cache.getMostRecentKernelRuntime(); + NVF_CHECK(optimized_fusion->isSegmented(), "segmentation didn't happen!"); +} + +// https://github.com/NVIDIA/Fuser/issues/3811 +TEST_F(NVFuserTest, ReductionSchedulerWithAdditionalIDOuterNormalization) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeContigConcreteTensor({1, -1, -1}); + fusion.addInput(tv0); + auto tv1 = makeContigTensor(3); + fusion.addInput(tv1); + + auto tv2 = sum(tv0, {0, 1}, /*keep_dim=*/true); + auto tv3 = add(tv0, tv2); + fusion.addOutput(tv3); + auto tv4 = add(tv0, tv1); + fusion.addOutput(tv4); + + fusion.printMath(); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({1, 20, 100}, options); + auto t1 = at::randn({128, 20, 100}, options); + std::vector inputs({t0, t1}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + + // checking segmentation + auto optimized_fusion = executor_cache.getMostRecentKernelRuntime(); + NVF_CHECK(optimized_fusion->isSegmented(), "segmentation didn't happen!"); + + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); +} + } // namespace nvfuser diff --git a/tests/cpp/test_resize.cpp b/tests/cpp/test_resize.cpp index 12ca31a929f..4aaac477a52 100644 --- a/tests/cpp/test_resize.cpp +++ b/tests/cpp/test_resize.cpp @@ -5970,7 +5970,7 @@ TEST_F(ResizeTest, AvoidCachingSliceInput) { } } -TEST_F(ResizeTest, VectorizeSliceMultiplePaths) { +TEST_F(ResizeTest, VectorizeInnerSliceMultiplePaths) { auto fusion_ptr = std::make_unique(); auto& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); @@ -6005,6 +6005,50 @@ TEST_F(ResizeTest, VectorizeSliceMultiplePaths) { EXPECT_EQ(tv6->getLoopDomain().back()->extent()->evaluate(), 2); } +// The current analysis is not precise enough to pass this test +TEST_F(ResizeTest, DISABLED_VectorizeOuterSliceMultiplePaths) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + const std::vector shape{4, 1024 * 1024}; + + auto tv0 = makeContigConcreteTensor(shape); + fusion.addInput(tv0); + + auto tv1 = + pad(tv0, + {fusion.zeroVal(), + fusion.zeroVal(), + IrBuilder::create(2), + IrBuilder::create(2)}); + auto tv2 = + pad(tv0, + {fusion.zeroVal(), + fusion.zeroVal(), + fusion.zeroVal(), + IrBuilder::create(4)}); + auto tv3 = add(tv1, tv2); + fusion.addOutput(tv3); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn(shape, options); + + auto outputs = scheduleAndRun(&fusion, SchedulerType::PointWise, {t0}); + testValidate(&fusion, outputs.outputs, {t0}, __LINE__, __FILE__); + + // While there's a pad with factor of 2, it shouldn't matter as the + // inner ID is large enough. + auto out_tv = tv3; + auto vec_id_it = + std::ranges::find_if(out_tv->getLoopDomain(), [](IterDomain* loop_id) { + return loop_id->getParallelType() == ParallelType::Vectorize; + }); + ASSERT_NE(vec_id_it, out_tv->getLoopDomain().end()) + << "Vectorized ID not found: " << out_tv->toString(); + EXPECT_EQ((*vec_id_it)->extent()->evaluate(), 4); +} + // Repro of issue #4202 TEST_F(ResizeTest, PropagateResizeThroughMultiplePaths) { auto fusion_ptr = std::make_unique(); @@ -6040,4 +6084,49 @@ TEST_F(ResizeTest, PropagateResizeThroughMultiplePaths) { testValidate(&fusion, outputs.outputs, {t0, t1}, __LINE__, __FILE__); } +// Check if vectorization is properly applied even when a resized ID +// is reachable from vectorized IDs. Pattern extracted from Litgpt +// LLama RoPE backward. +TEST_F(ResizeTest, VectorizeOuterPad) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + const std::vector shape1{1, 8, 4, 8192, 128}; + const std::vector shape2{1, 8, 1, 8192, 128}; + auto tv0 = makeContigConcreteTensor(shape1, DataType::BFloat16); + fusion.addInput(tv0); + auto tv1 = makeContigConcreteTensor(shape2, DataType::BFloat16); + fusion.addInput(tv1); + auto tv2 = makeContigConcreteTensor(shape2, DataType::BFloat16); + fusion.addInput(tv2); + + // [1, 8, 6, 8192, 128] + auto tv3 = cat({tv0, tv1, tv2}, 2); + // [1, 8192, 8, 6, 128] + auto tv4 = permute(tv3, {0, 3, 1, 2, 4}); + auto tv5 = reshape(tv4, {1, 8192, 8, 6, 128}, {1, 8192, 6144}); + fusion.addOutput(tv5); + + auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0); + auto t0 = at::randn(shape1, options); + auto t1 = at::randn(shape2, options); + auto t2 = at::randn(shape2, options); + + auto outputs = + scheduleAndRun(&fusion, SchedulerType::PointWise, {t0, t1, t2}); + testValidate(&fusion, outputs.outputs, {t0, t1, t2}, __LINE__, __FILE__); + + auto out_tv = tv5; + // While there's a pad with factor of 2, it shouldn't matter as the + // inner ID is large enough. + auto vec_id_it = + std::ranges::find_if(out_tv->getLoopDomain(), [](IterDomain* loop_id) { + return loop_id->getParallelType() == ParallelType::Vectorize; + }); + ASSERT_NE(vec_id_it, out_tv->getLoopDomain().end()) + << "Vectorized ID not found: " << out_tv->toString(); + EXPECT_EQ((*vec_id_it)->extent()->evaluate(), 8); +} + } // namespace nvfuser diff --git a/tests/cpp/test_segmentation.cpp b/tests/cpp/test_segmentation.cpp index 27ae0dbefbd..acb65bff7f6 100644 --- a/tests/cpp/test_segmentation.cpp +++ b/tests/cpp/test_segmentation.cpp @@ -695,7 +695,7 @@ TEST_F(SegmentationTest, ForwardInputsToSegmenterSetIssue2658) { } // Test to verify an upcast is replicated between different segments -TEST_F(NVFuserTest, PrivatizeUpcast) { +TEST_F(SegmentationTest, PrivatizeUpcast) { auto fusion_ptr = std::make_unique(); auto& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); @@ -741,7 +741,7 @@ TEST_F(NVFuserTest, PrivatizeUpcast) { // Unlike PrivatizeUpcast, verify replicated upcast ops are // consolidated back as they are grouped into the same segment -TEST_F(NVFuserTest, RevertPrivatizedUpcast) { +TEST_F(SegmentationTest, RevertPrivatizedUpcast) { auto fusion_ptr = std::make_unique(); auto& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); @@ -750,19 +750,31 @@ TEST_F(NVFuserTest, RevertPrivatizedUpcast) { fusion.addInput(tv0); auto tv1 = segment_set(tv0); - auto tv2 = castOp(DataType::Float, tv1); - auto tv3 = sum(tv2, {1}); - fusion.addOutput(tv3); + auto tv2 = set(tv1); + auto tv3 = castOp(DataType::Float, tv2); - auto tv4 = sum(tv2, {1}); + auto tv4 = sum(tv3, {1}); fusion.addOutput(tv4); + auto tv5 = sum(tv3, {1}); + fusion.addOutput(tv5); + auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0); auto t0 = at::randn({16, 32}, options); FusionExecutorCache executor_cache(std::move(fusion_ptr)); - auto outputs = executor_cache.runFusionWithInputs({t0}); + KernelArgumentHolder outputs; + + // Make sure NVFUSER_DUMP=segmented_fusion works + { + DebugDumpOptionsGuard options_guard; + DebugDumpOptionsGuard::getCurOptions().set(DebugDumpOption::FusionSegments); + std::ostringstream tmp_buf; + DebugStreamGuard debug_stream_guard(tmp_buf); + outputs = executor_cache.runFusionWithInputs({t0}); + } + testValidate(&fusion, outputs, {t0}, __LINE__, __FILE__); // There must be two segments, one with ExprEvalExecutor and another @@ -787,7 +799,7 @@ TEST_F(NVFuserTest, RevertPrivatizedUpcast) { continue; } - EXPECT_EQ(uop->in()->as()->view()->name(), 1); + EXPECT_EQ(uop->in()->as()->view()->name(), 2); ++num_upcast_ops; } @@ -795,4 +807,59 @@ TEST_F(NVFuserTest, RevertPrivatizedUpcast) { } } +TEST_F(SegmentationTest, ForwardFull) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeSymbolicTensor(1); + fusion.addInput(tv0); + + // FullOp that is used in two segments + auto tv1 = full({tv0->axis(0)->extent()}, fusion.oneVal(), DataType::Float); + + auto tv2 = add(tv0, tv1); + auto tv3 = segment_set(tv2); + + auto tv4 = add(tv3, tv1); + fusion.addOutput(tv4); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn({1024}, options); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs({t0}); + testValidate(&fusion, outputs, {t0}, __LINE__, __FILE__); + + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_THAT(runtime->fusionSegments()->groups(), SizeIs(2)); + + // Make sure the full output should not be a segment input + for (const auto& executor : runtime->executors()) { + auto ke = dynamic_cast(executor.get()); + ASSERT_NE(ke, nullptr); + kir::Kernel* kernel = ke->compiledKernel()->kernel(); + bool full_op_found = false; + for (auto expr : KernelExprVisitor::getAllExprs(kernel)) { + auto out_tv = ir_utils::getTvOutput(expr); + if (out_tv == nullptr) { + continue; + } + auto full_op = dynamic_cast(out_tv->definition()); + if (full_op == nullptr) { + continue; + } + full_op_found = true; + auto output_it = + std::ranges::find_if(kernel->outputs(), [&](Val* output) { + return output->isA() && + output->name() == out_tv->name(); + }); + EXPECT_EQ(output_it, kernel->outputs().end()) + << "FullOp ouput should not be a segment output"; + } + EXPECT_TRUE(full_op_found) << "Each segment has its own FullOp"; + } +} + } // namespace nvfuser diff --git a/tests/cpp/test_sharding.cpp b/tests/cpp/test_sharding.cpp index 1ce1d96d8d0..ffbbabb7402 100644 --- a/tests/cpp/test_sharding.cpp +++ b/tests/cpp/test_sharding.cpp @@ -234,6 +234,41 @@ TEST_F(ShardingTest, MultiDimDeviceMesh) { EXPECT_EQ(mesh3d.getSlice(18, ParallelType::DIDx), slice_didx); } +TEST_F(ShardingTest, ResidualAdd) { + // This is similar to the residual add after MHA dropout in the transformer. + // The output of linear following MHA is all-gathered and sharded on the + // sequence dim. This sharding can be propagated to the linear output through + // backpropagating the shardings from residual add. This information is not + // present during forward propagation. + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + DeviceMesh mesh({0, 1}); + + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = uniform( + shape(tv0), + fusion->zeroVal(DataType::Float), + fusion->oneVal(DataType::Float), + DataType::Float); + TensorView* tv2 = add(tv0, tv1); + + tv0->setDeviceMesh(mesh); + tv0->outer_split(0, mesh.size()); + tv0->axis(0)->parallelize(ParallelType::DIDx); + + fusion->addInput(tv0); + fusion->addOutput(tv1); + fusion->addOutput(tv2); + + preseg_passes::OptimizationPass< + preseg_passes::PropagateShardingsPass>::runPass(fusion.get()); + NVF_CHECK(tv1->hasDeviceMesh()); + NVF_CHECK( + getShardedLogicalAxis(tv1, ParallelType::DIDx) == + getShardedLogicalAxis(tv0, ParallelType::DIDx), + "Expected tv1 to be sharded like tv0 due to backpropagation of shardings."); +} + INSTANTIATE_TEST_SUITE_P( , ShardingTest, diff --git a/tests/cpp/test_tmem.cpp b/tests/cpp/test_tmem.cpp index 547b0de840f..ab2b7aebfcb 100644 --- a/tests/cpp/test_tmem.cpp +++ b/tests/cpp/test_tmem.cpp @@ -290,7 +290,7 @@ TEST_F(TMemTestCompileOnly, SetTMemDimSepPosNonTMem) { // But in the TMem load/store's loop domain, Ix (the ID parallelized on TIDx) // have extent 32. Then we will generate code like: // if (threadIdx.x < 32) { -// tmem::load +// tcgen05::load // } // For threadIdx.y == 0, it is correct. But for threadIdx.y == 1, it is wrong // because we are using the thread id 33-65 for the load, which is not a warp. @@ -342,7 +342,7 @@ TEST_F(TMemTestCompileOnly, WrongStride) { // map is [TIDy, TIDx] = [2, 33], but in the TMem load/store's loop domain, // we have Iy{1}, Ix{32}. the generated code will be like: // if (threadIdx.x < 32 && threadIdx.y < 1) { -// tmem::load +// tcgen05::load // } // This is valid because we are using a whole warp for the load. TEST_F(TMemTest, InexactParallelType) { diff --git a/tests/python/multidevice/conftest.py b/tests/python/multidevice/conftest.py new file mode 100644 index 00000000000..7bd82c3c6ee --- /dev/null +++ b/tests/python/multidevice/conftest.py @@ -0,0 +1,91 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import nvfuser +import pytest +import torch +import torch.distributed as dist + + +class MultideviceTest: + def __init__(self): + self._communicator = nvfuser.Communicator.instance() + + # This way, when individual tests create unsharded input, each rank + # receives the same data. + torch.manual_seed(0) + + @property + def communicator(self): + return self._communicator + + @property + def size(self): + return self._communicator.size() + + @property + def rank(self): + return self._communicator.rank() + + @property + def local_size(self): + return self._communicator.local_size() + + @property + def local_rank(self): + return self._communicator.local_rank() + + def shard_tensor( + self, t: torch.Tensor, dim: int, mesh: nvfuser.DeviceMesh + ) -> torch.Tensor: + assert t.is_cpu, ( + "This is not strictly required but it's a general good practice " + "for unit tests to create unsharded data on CPU to reduce GPU " + "memory footprint." + ) + return mesh.shard_tensor(t, dim, self.rank).cuda(self.rank) + + +@pytest.fixture +def multidevice_test(): + # Reset the cache here to work around a bug in FusionDefintion.execute. + # FusionDefinition._finalize_definition maps the same `definition` to the + # same FusionSchedules and therefore the same FusionExecutorCache. This was + # correct until multiple FusionDefinitions started to have the same + # `definition` but different `multidevice_schedule`s. This seems to be a + # known issue beacuse a similar workaround for single-GPU schedules is done + # here: + # https://github.com/NVIDIA/Fuser/blob/f44f1913c26f8325099ab6fe46d678cbea435658/tests/python/test_schedule_ops.py#L115. + # + # I couldn't think of an easy way to fix this issue properly. Also, that + # FusionCache is obsolete makes me less motivated to do so. + nvfuser.FusionCache.reset() + + fixture = MultideviceTest() + yield fixture + # Sync all ranks after each test for isolation. + fixture.communicator.barrier() + + +# Set up the default process group for torch APIs like +# dist.device_mesh.init_device_mesh. +# +# This fixture is used by multi-GPU tests that use torch.distributed. +# +# I use "session" instead of "module" because +# https://github.com/pytorch/pytorch/issues/119196 reported race conditions +# when reinitializing process groups. +@pytest.fixture(scope="session") +def setup_default_process_group(): + communicator = nvfuser.Communicator.instance() + + # The default port as used by https://github.com/pytorch/pytorch/blob/45a8b5682eb69d865cbf68c7f2f689b56b4efd53/torch/csrc/distributed/c10d/TCPStore.hpp#L51. + dist.init_process_group( + backend="nccl", + init_method="tcp://localhost:29500", + world_size=communicator.size(), + rank=communicator.rank(), + ) + yield + dist.destroy_process_group() diff --git a/tests/python/multidevice/fixtures.py b/tests/python/multidevice/fixtures.py deleted file mode 100644 index f14a884fbb9..00000000000 --- a/tests/python/multidevice/fixtures.py +++ /dev/null @@ -1,54 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2024-present NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause - -import nvfuser -import pytest -import torch - - -class MultideviceTest: - def __init__(self): - self._communicator = nvfuser.Communicator.instance() - - # This way, when individual tests create unsharded input, each rank - # receives the same data. - torch.manual_seed(0) - - @property - def communicator(self): - return self._communicator - - @property - def size(self): - return self._communicator.size() - - @property - def rank(self): - return self._communicator.rank() - - @property - def local_size(self): - return self._communicator.local_size() - - @property - def local_rank(self): - return self._communicator.local_rank() - - def shard_tensor( - self, t: torch.Tensor, dim: int, mesh: nvfuser.DeviceMesh - ) -> torch.Tensor: - assert t.is_cpu, ( - "This is not strictly required but it's a general good practice " - "for unit tests to create unsharded data on CPU to reduce GPU " - "memory footprint." - ) - return mesh.shard_tensor(t, dim, self.rank).cuda(self.rank) - - -@pytest.fixture(scope="session") -def multidevice_test(): - fixture = MultideviceTest() - yield fixture - # Sync all ranks after each test for isolation. - fixture.communicator.barrier() diff --git a/tests/python/multidevice/test_communication.py b/tests/python/multidevice/test_communication.py index 5adf06f6882..5b0d097fe8e 100644 --- a/tests/python/multidevice/test_communication.py +++ b/tests/python/multidevice/test_communication.py @@ -5,14 +5,10 @@ import pytest import torch -import fixtures import nvfuser from nvfuser import DataType, FusionDefinition -multidevice_test = fixtures.multidevice_test - - @pytest.mark.mpi def test_allgather(multidevice_test): d = multidevice_test.size diff --git a/tests/python/multidevice/test_deepseek_v3.py b/tests/python/multidevice/test_deepseek_v3.py new file mode 100644 index 00000000000..284563c398a --- /dev/null +++ b/tests/python/multidevice/test_deepseek_v3.py @@ -0,0 +1,125 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +import pytest +import transformers +import torch +import torch.distributed as dist +from contextlib import contextmanager +from torch.distributed.tensor import DTensor +from torch.distributed.tensor.parallel import ( + parallelize_module, + RowwiseParallel, + ColwiseParallel, +) + + +@contextmanager +def default_tensor_type(dtype=torch.float32, device="cpu"): + # Save + prev_dtype = torch.get_default_dtype() + prev_device = torch.get_default_device() + + # Set + torch.set_default_dtype(dtype) + torch.set_default_device(device) + + yield + + # Restore + torch.set_default_dtype(prev_dtype) + torch.set_default_device(prev_device) + + +# This test timed out once when downloading +# "/deepseek-ai/DeepSeek-V3/resolve/main/configuration_deepseek.py" (cf. +# http://nv/eCm). I consider this a one-off, but please let me know if this +# error becomes consistent. +@pytest.mark.mpi +def test_transformer_layer(setup_default_process_group): + config = transformers.AutoConfig.from_pretrained( + "deepseek-ai/deepseek-v3", trust_remote_code=True + ) + + # Create only one layer which is sufficient for the test. + config.num_hidden_layers = 1 + # Without this, the first and only layer will have a dense MLP instead of MoE. + config.first_k_dense_replace = 0 + # Disable quantization so the test can run on A100 and is made easier for nvFuser. + delattr(config, "quantization_config") + + d = dist.get_world_size() + rank = dist.get_rank() + torch.cuda.set_device(rank) + # This ensures the input tokens are identically replicated on all ranks. + # Otherwise, some ranks may skip an expert because they have no tokens to + # send, while other ranks don't. This will cause a deadlock because a NCCL + # collective is expected to be called by all ranks in the process group. + torch.manual_seed(0) + + mesh = dist.device_mesh.init_device_mesh("cuda", [d]) + + with default_tensor_type(dtype=config.torch_dtype, device="cuda"): + model = transformers.AutoModel.from_config(config, trust_remote_code=True) + # Training is unavailable (cf. https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L439) + model.eval() + + transformer_layer = model.layers[0] + + # By default, RowwiseParallel and ColwiseParallel output a local tensor + # and therefore num_heads needs to be adjusted to accomodate the local + # size. Alternatively, I could RowwiseParallel(use_local_output=False) + # so the linear layer outputs a DTensor, which can be viewed using the + # original num_heads. This requires all activations, parameters, and + # buffers to be DTensor; otherwise aten ops would complain "got mixed + # torch.Tensor and DTensor". Doing so is challenging because + # DeepseekV3RotaryEmbedding creates cos_cached and sin_cached during + # the first forward call (cf. + # https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L143-L144). + transformer_layer.self_attn.num_heads //= d + + # Create the parallel plan + parallel_plan = { + "self_attn.q_b_proj": ColwiseParallel(), + "self_attn.kv_b_proj": ColwiseParallel(), + "self_attn.o_proj": RowwiseParallel(), + } + + for expert in range(config.n_routed_experts): + parallel_plan[f"mlp.experts.{expert}.gate_proj"] = ColwiseParallel() + parallel_plan[f"mlp.experts.{expert}.up_proj"] = ColwiseParallel() + parallel_plan[f"mlp.experts.{expert}.down_proj"] = RowwiseParallel() + + parallel_plan["mlp.shared_experts.gate_proj"] = ColwiseParallel() + parallel_plan["mlp.shared_experts.up_proj"] = ColwiseParallel() + parallel_plan["mlp.shared_experts.down_proj"] = RowwiseParallel() + + transformer_layer = parallelize_module( + transformer_layer, + mesh, + parallel_plan, + ) + + # Sanity-check parameters are indeed distributed + distributed_params: list[str] = [ + name + for name, parameter in transformer_layer.named_parameters() + if isinstance(parameter.data, DTensor) + ] + assert len(distributed_params) == 3 + (config.n_routed_experts + 1) * 3 + + batch_size = 1 + seq_len = 2048 + inp = torch.randn(batch_size, seq_len, config.hidden_size) + mask = transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask( + None, [batch_size, seq_len], inp, past_key_values_length=0 + ) + (out,) = transformer_layer(inp, attention_mask=mask) + # Finish all computation and communication. Otherwise, + # destroy_process_group may deadlock. + torch.cuda.synchronize() + + assert out.size() == (batch_size, seq_len, config.hidden_size) + assert out.dtype == config.torch_dtype + assert out.is_cuda diff --git a/tests/python/multidevice/test_dtensor.py b/tests/python/multidevice/test_dtensor.py index 0cdb52cda27..51644b7bc0d 100644 --- a/tests/python/multidevice/test_dtensor.py +++ b/tests/python/multidevice/test_dtensor.py @@ -2,7 +2,6 @@ # All rights reserved. # SPDX-License-Identifier: BSD-3-Clause -import fixtures import nvfuser import pytest import torch @@ -16,22 +15,6 @@ from typing import Callable, cast -multidevice_test = fixtures.multidevice_test - - -@pytest.fixture(scope="module") -def setup_process_group(multidevice_test): - # The default port as used by https://github.com/pytorch/pytorch/blob/45a8b5682eb69d865cbf68c7f2f689b56b4efd53/torch/csrc/distributed/c10d/TCPStore.hpp#L51. - dist.init_process_group( - backend="nccl", - init_method="tcp://localhost:29500", - world_size=multidevice_test.size, - rank=multidevice_test.rank, - ) - yield - dist.destroy_process_group() - - class FusionDefinitionWrapper: def __init__(self, define_fusion: Callable[[FusionDefinition], None]): """Wraps a function that defines a fusion without `multidevice_schedule`.""" @@ -94,7 +77,7 @@ def __call__(self, in_dtensors: Iterable[DTensor]) -> list[DTensor]: @pytest.mark.mpi -def test_plus_one(setup_process_group): +def test_plus_one(setup_default_process_group, multidevice_test): def define_fusion(fd: FusionDefinition): inp = fd.define_tensor((-1, -1), contiguity=False, dtype=DataType.Float) one = fd.define_scalar(1.0, dtype=DataType.Float) @@ -118,7 +101,7 @@ def define_fusion(fd: FusionDefinition): @pytest.mark.mpi -def test_linear(setup_process_group): +def test_linear(setup_default_process_group, multidevice_test): @dataclass class LinearConfig: def __init__(self, num_devices: int, batch: int, sequence: int, hidden: int): diff --git a/tests/python/multidevice/test_matmul.py b/tests/python/multidevice/test_matmul.py index c09e4142e26..1c877b4acb5 100644 --- a/tests/python/multidevice/test_matmul.py +++ b/tests/python/multidevice/test_matmul.py @@ -5,12 +5,9 @@ import pytest import torch -import fixtures import nvfuser from nvfuser import DataType, FusionDefinition -multidevice_test = fixtures.multidevice_test - # Avoid doing this when possible. This test started to exist before nvFuser # supports DID loop split. As a result of that, the weight in this test has to be @@ -84,7 +81,7 @@ def definition(self): self.add_output(self.out) def multidevice_schedule(self): - for t in [self.inp, self.weight, self.bias, self.out]: + for t in [self.inp, self.weight, self.bias]: self.sched._set_device_mesh(t, mesh) # Shard N for weight (N, K) and bias (N) @@ -93,12 +90,6 @@ def multidevice_schedule(self): self.sched.parallelize(t, 0, nvfuser.ParallelType.mesh_x) self.sched.set_allocation_as_loop(t) - # Output of linear: {.., i{M}, i{N}, r{K}} - # Shard N -> axis(-2) - self.sched.split(self.out, -2, d, False) - self.sched.parallelize(self.out, -3, nvfuser.ParallelType.mesh_x) - self.sched.set_allocation_as_loop(self.out) - torch.cuda.set_device(multidevice_test.local_rank) b, s = 2, 1024 @@ -138,7 +129,7 @@ def definition(self): self.add_output(self.out) def multidevice_schedule(self): - for t in [self.inp, self.weight, self.out]: + for t in [self.inp, self.weight]: self.sched._set_device_mesh(t, mesh) self.sched.split(t, -1, d, False) self.sched.parallelize(t, -2, nvfuser.ParallelType.mesh_x) @@ -161,6 +152,52 @@ def multidevice_schedule(self): torch.testing.assert_close(out.cpu(), unsharded_out, rtol=1.3e-6, atol=1e-3) +@pytest.mark.mpi +def test_linear_reduce_scatter(multidevice_test): + d = multidevice_test.size + mesh = nvfuser.DeviceMesh(range(d)) + e = 768 + + class Model(FusionDefinition): + def definition(self): + self.inp = self.define_tensor([-1, -1, d * e]) + self.weight = self.define_tensor([e, d * e]) + self.out = self.ops.linear(self.inp, self.weight, None) + self.add_output(self.out) + + def multidevice_schedule(self): + for t in [self.inp, self.weight, self.out]: + self.sched._set_device_mesh(t, mesh) + self.sched.split(t, -1, d, False) + self.sched.parallelize(t, -2, nvfuser.ParallelType.mesh_x) + self.sched.set_allocation_as_loop(t) + + # Scatter + self.sched.split(self.out, 1, d, False) + self.sched.parallelize(self.out, 1, nvfuser.ParallelType.mesh_x) + + torch.cuda.set_device(multidevice_test.local_rank) + + b, s = 2, 1024 + unsharded_inp = torch.randn(b, s, d * e) + unsharded_weight = torch.randn(e, d * e) + + inp = multidevice_test.shard_tensor(unsharded_inp, -1, mesh) + weight = multidevice_test.shard_tensor(unsharded_weight, -1, mesh) + + fd = Model() + (out,), _ = fd.execute([inp, weight]) + + unsharded_out = torch.nn.functional.linear(unsharded_inp, unsharded_weight, None) + # rtol is the same as the default for fp32. atol is slightly increased. + torch.testing.assert_close( + out, + multidevice_test.shard_tensor(unsharded_out, 1, mesh), + rtol=1.3e-6, + atol=1e-3, + ) + + @pytest.mark.mpi def test_matmul_allreduce(multidevice_test): d, b, s, e = multidevice_test.size, 1, 4, 8 diff --git a/tests/python/multidevice/test_multidevice.py b/tests/python/multidevice/test_multidevice.py index dd3c3fb5877..6489a0a7691 100644 --- a/tests/python/multidevice/test_multidevice.py +++ b/tests/python/multidevice/test_multidevice.py @@ -7,13 +7,10 @@ from enum import Enum, auto from torch.nn.attention import SDPBackend -import fixtures import nvfuser from nvfuser import DataType, FusionDefinition from nvfuser.testing.utils import create_sdpa_rng_tensors, define_sdpa_rng_state -multidevice_test = fixtures.multidevice_test - @pytest.mark.mpi def test_sizes_and_ranks(multidevice_test): diff --git a/tests/python/multidevice/test_overlap.py b/tests/python/multidevice/test_overlap.py index 460a1f0edd8..34850477376 100644 --- a/tests/python/multidevice/test_overlap.py +++ b/tests/python/multidevice/test_overlap.py @@ -6,12 +6,9 @@ import torch import os -import fixtures import nvfuser from nvfuser import DataType, FusionDefinition, CommunicatorBackend -multidevice_test = fixtures.multidevice_test - class OverlapAGMatmulStreamOutermost(FusionDefinition): def __init__(self, m, k, n, s, num_devices, communication_backend): @@ -88,9 +85,6 @@ def test_overlap_allgather_matmul_stream_outermost( ins = [x, weight, bias] out_ref = torch.nn.functional.linear(x_unsharded, weight.cpu(), bias.cpu()) - # Resetting the cache here is necessary to workaround a bug that would need a proper fix. If not avoiding the cache, there is an issue for the second test that is being run. More specifically, the second time we define the fusion, we hit the cache in https://github.com/NVIDIA/Fuser/blob/6ff60e2a320733a2f49de57007d6bb45000107cd/csrc/python_frontend/fusion_definition.cpp#L95 . Later, when we call _set_device_mesh, we get a "thro out of range" here https://github.com/NVIDIA/Fuser/blob/6ff60e2a320733a2f49de57007d6bb45000107cd/csrc/python_frontend/schedule_bindings.cpp#L60 because the FusionDefinition has not run so it doesn't contain any state. - nvfuser.FusionCache.reset() - fd = OverlapAGMatmulStreamOutermost(m, k, n, s, d, backend_type) # warmup diff --git a/tests/python/multidevice/test_transformer_engine.py b/tests/python/multidevice/test_transformer_engine.py index db3046705be..14110e147c6 100644 --- a/tests/python/multidevice/test_transformer_engine.py +++ b/tests/python/multidevice/test_transformer_engine.py @@ -58,19 +58,6 @@ class Parallelism(Enum): SEQUENCE_PARALLEL = auto() -@pytest.fixture(scope="module") -def setup_process_group(mpi_test) -> None: - # The default port as used by https://github.com/pytorch/pytorch/blob/45a8b5682eb69d865cbf68c7f2f689b56b4efd53/torch/csrc/distributed/c10d/TCPStore.hpp#L51. - dist.init_process_group( - backend="nccl", - init_method="tcp://localhost:29500", - world_size=mpi_test.size, - rank=mpi_test.rank, - ) - yield - dist.destroy_process_group() - - # This benchmark is instrumented with cudaProfilerStart/Stop. Therefore, one # can collect stats of the first few non-warmup benchmark iterations using # ```bash @@ -94,7 +81,7 @@ def setup_process_group(mpi_test) -> None: ids=["nonoverlap", "overlap"], ) def test_transformer_layer( - setup_process_group, + setup_default_process_group, monkeypatch, benchmark, compute_type: ComputeType, diff --git a/tests/python/opinfo_input_generators.py b/tests/python/opinfo_input_generators.py index d6653841e6f..199ced4de8b 100644 --- a/tests/python/opinfo_input_generators.py +++ b/tests/python/opinfo_input_generators.py @@ -822,6 +822,25 @@ def index_select_error_generator( # yield SampleInput(a, b, 0), RuntimeError, "out of bounds index value." +def index_put_accumulate_generator( + op: OpInfo, dtype: torch.dtype, requires_grad: bool = False, **kwargs +): + make_arg = partial( + make_tensor, device="cuda", dtype=dtype, requires_grad=requires_grad + ) + make_index = partial(make_tensor, device="cuda", requires_grad=False) + + # vocab_size, hidden_size, seq_size + cases = ((1024, 12, 300),) + + for vocab, hidden, seq in cases: + for index_dtype in [torch.int, torch.long]: + acc = make_arg((vocab, hidden)) + index = make_index((seq,), low=0, high=vocab, dtype=index_dtype) + value = make_arg((seq, hidden)) + yield SampleInput(acc, index, value) + + def iota_error_generator( op: OpInfo, dtype: torch.dtype, requires_grad: bool = False, **kwargs ): diff --git a/tests/python/opinfos.py b/tests/python/opinfos.py index 14c4f0e8f9f..70fb1c32153 100644 --- a/tests/python/opinfos.py +++ b/tests/python/opinfos.py @@ -29,6 +29,7 @@ gather_generator, index_select_generator, index_select_error_generator, + index_put_accumulate_generator, iota_error_generator, pad_error_generator, permute_generator, @@ -1023,6 +1024,33 @@ def gather_wrapper(fn: callable, input: torch.Tensor, index: torch.Tensor, dim: ) shape_ops.append(index_select_opinfo) + +def index_put_accumulate_ref( + acc: torch.Tensor, index: torch.Tensor, value: torch.Tensor +): + return torch.index_put( + acc, + [ + index, + ], + value, + accumulate=True, + ) + + +index_put_accumulate_opinfo = OpInfo( + lambda fd: fd.ops.index_put_accumulate, + "index_put_accumulate", + sample_input_generator=index_put_accumulate_generator, + reference=index_put_accumulate_ref, + symbolic_parameter_list=( + ArgumentType.Symbolic, + ArgumentType.Symbolic, + ArgumentType.Symbolic, + ), +) +shape_ops.append(index_put_accumulate_opinfo) + # NvFuser's API is significantly different than JAX. # TODO: Change python frontend api to match JAX using a cpp wrapper function. pad_opinfo = OpInfo( diff --git a/tests/python/test_deepseek_v3.py b/tests/python/test_deepseek_v3.py deleted file mode 100644 index 1c9f2acb6f7..00000000000 --- a/tests/python/test_deepseek_v3.py +++ /dev/null @@ -1,58 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause - -import pytest -import transformers -import torch -from contextlib import contextmanager - - -@contextmanager -def default_tensor_type(dtype=torch.float32, device="cpu"): - # Save - prev_dtype = torch.get_default_dtype() - prev_device = torch.get_default_device() - - # Set - torch.set_default_dtype(dtype) - torch.set_default_device(device) - - yield - - # Restore - torch.set_default_dtype(prev_dtype) - torch.set_default_device(prev_device) - - -@pytest.mark.skip(reason="flaky on CI due to download timeout: http://nv/eCm") -def test_transformer_layer(): - config = transformers.AutoConfig.from_pretrained( - "deepseek-ai/deepseek-v3", trust_remote_code=True - ) - - # Create only one layer which is sufficient for the test. - config.num_hidden_layers = 1 - # Without this, the first and only layer will have a dense MLP instead of MoE. - config.first_k_dense_replace = 0 - # Disable quantization so the test can run on A100 and is made easier for nvFuser. - delattr(config, "quantization_config") - - with default_tensor_type(dtype=config.torch_dtype, device="cuda"): - model = transformers.AutoModel.from_config(config, trust_remote_code=True) - # Training is unavailable (cf. https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/modeling_deepseek.py#L439) - model.eval() - - transformer_layer = model.layers[0] - - batch_size = 1 - seq_len = 2048 - inp = torch.randn(batch_size, seq_len, config.hidden_size) - mask = transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask( - None, [batch_size, seq_len], inp, past_key_values_length=0 - ) - (out,) = transformer_layer(inp, attention_mask=mask) - - assert out.size() == (batch_size, seq_len, config.hidden_size) - assert out.dtype == config.torch_dtype - assert out.is_cuda diff --git a/tools/gen_nvfuser_version.py b/tools/gen_nvfuser_version.py deleted file mode 100644 index a09eda53539..00000000000 --- a/tools/gen_nvfuser_version.py +++ /dev/null @@ -1,75 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause -import subprocess -import sys -from pathlib import Path - -UNKNOWN = "Unknown" -nvfuser_root = Path(__file__).parent.parent - - -# note that this root currently is still part of pytorch. -def get_sha() -> str: - try: - return ( - subprocess.check_output(["git", "rev-parse", "HEAD"], cwd=nvfuser_root) - .decode("ascii") - .strip() - ) - except Exception: - import os - - # assume the $NVFUSER_VERSION is in sha form - if nvfuser_version := os.environ.get("NVFUSER_VERSION"): - assert ( - len(nvfuser_version) < 11 - ), "The NVFUSER_VERSION should be in sha form" - return nvfuser_version - return UNKNOWN - - -def get_version() -> str: - sha = get_sha() - version = ( - open((nvfuser_root / "version.txt"), "r").read().strip() + "+git" + sha[:7] - ) - return version - - -def get_pytorch_cmake_prefix(): - from subprocess import Popen, PIPE - - # need to do this in a separate process so we are not going to delete nvfuser library while it's loaded by torch - process_torch_prefix = Popen( - [ - sys.executable, - "-c", - "import torch.utils; print(torch.utils.cmake_prefix_path)", - ], - stdout=PIPE, - ) - stdout_msg, error_msg = process_torch_prefix.communicate() - return stdout_msg.decode("utf-8").rstrip("\n") - - -def get_pytorch_use_distributed(): - from subprocess import Popen, PIPE - - # need to do this in a separate process so we are not going to delete nvfuser library while it's loaded by torch - process_torch_prefix = Popen( - [ - sys.executable, - "-c", - "import torch; print(torch._C._has_distributed())", - ], - stdout=PIPE, - ) - stdout_msg, error_msg = process_torch_prefix.communicate() - return stdout_msg.decode("utf-8").rstrip("\n") - - -if __name__ == "__main__": - version_file = nvfuser_root / "nvfuser" / "version.py" - with open(version_file, "w") as f: - f.write("_version_str = '{}'\n".format(get_version())) diff --git a/tools/gen_nvfuser_version.py b/tools/gen_nvfuser_version.py new file mode 120000 index 00000000000..fef974811db --- /dev/null +++ b/tools/gen_nvfuser_version.py @@ -0,0 +1 @@ +../python/tools/gen_nvfuser_version.py \ No newline at end of file diff --git a/tools/memory.py b/tools/memory.py deleted file mode 100644 index 1ed95f8ded5..00000000000 --- a/tools/memory.py +++ /dev/null @@ -1,28 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. -# All rights reserved. -# SPDX-License-Identifier: BSD-3-Clause - - -def get_available_memory_gb(): - """Returns the available memory in GB.""" - try: - import psutil - - return psutil.virtual_memory().available / 1024 / 1024 / 1024 - except: # noqa: E722 - pass - - try: - with open("/proc/meminfo", "r") as f: - while True: - line = f.readline() - if line.startswith("MemAvailable:"): - mem = line.split()[1] - assert line.split()[2] == "kB" - return int(mem) / 1024 / 1024 - if not line: - break - except: # noqa: E722 - pass - - return 0 diff --git a/tools/memory.py b/tools/memory.py new file mode 120000 index 00000000000..d818457a563 --- /dev/null +++ b/tools/memory.py @@ -0,0 +1 @@ +../python/tools/memory.py \ No newline at end of file