
[Host irs] Stream lowering of single device fusions#4147

Merged
samnordmann merged 72 commits into host_irs/LoadStore_Reduction_binaryOp_support from
host_irs/stream_lowering/single_device_fusions
Apr 28, 2025

Conversation

@samnordmann
Collaborator

@samnordmann samnordmann commented Mar 26, 2025

This PR belongs to a series of stacked PRs:

  1. [Host irs] alias and preallocated output support #4144
  2. [Host Ir] refactor and cleanup lowering and segmentation #4145
  3. [Host ir] support for set reduce and binary op #4146
  4. add HirAliasSelect #4301
  5. => You are here: [Host irs] Stream lowering of single device fusions #4147

What

Implement a proper lowering for handling ParallelType::Stream. This PR has the following restrictions:

  • Single device fusion
  • No split/merge of Stream axis

We add to the Host IR lowering a new pass that walks the HIR container's top-level expressions, reads each consumer's Stream parallelization, and creates for-loops with stream management and synchronization to express the stream parallelism. Basic logic for merging for-loops is also implemented.

Let me explain through some examples that can be found in the PR. We suggest running these examples as follows:

NVFUSER_DUMP=host_ir test_host_ir --gtest_filter=*

Single expr and for-loop

Look at MultiDeviceExecutorLowerStreamTest.SingleSetOp, a simple scenario:

  TensorView* tv0 = makeContigTensor(2);
  TensorView* tv1 = set(tv0);
  fusion->addInput(tv0);
  fusion->addOutput(tv1);
  tv1->axis(0)->parallelize(ParallelType::Stream);

The dumped generated Host IR program is:

%HostIrContainer { (T0_g_float[iS0{i0}, iS1{i2}]) -> (T1_g_float[iStreamIdx2{i0}, iS3{i2}]) :
  T1_g_float[iStreamIdx2{i0}, iS3{i2}] = ALLOCATE(buffer=T1_g_float[iStreamIdx2{i0}, iS3{i2}], mem_type=global, size=( i0 * i2 ), zero_init=false, resets_to_zero=false)
  FOR StreamIdx in iStreamIdx2{i0}:
    GetCurrentStream into Stream 0
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
    T2_l_float[iS4{i2}]
       = HirAliasSelect( T0_g_float[iS0{i0}, iS1{i2}], axis = iS0{i0}, index = StreamIdx )
    T3_l_float[iS5{i2}]
       = HirAliasSelect( T1_g_float[iStreamIdx2{i0}, iS3{i2}], axis = iStreamIdx2{i0}, index = StreamIdx )
    T3_l_float[iS5{i2}]
       = Set( T2_l_float[iS4{i2}], cache_op=Streaming )
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer

We can see that the expr, here the "Set", gets embedded into a for-loop. Let us analyze further:

  • Outside the for-loop, we allocate the global output buffer.
  • The start of the for-loop body assigns the new stream and syncs that stream with the user stream.
  • Then, we "select" (i.e., slice) into the input and output through HirAliasSelect.
  • The "Set" operation is executed on the "selected" I/O. Note that the output is an alias to the output's slice.
  • At the end of the for-loop, we reset to the user's stream (i.e., the stream that was current before entering the program) and sync the user's stream with the running stream.
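This prologue/epilogue sequence is the usual CUDA multi-stream idiom: switch to a worker stream, make it wait on the user stream, run the body, then restore the user stream and make it wait back. Here is a minimal standalone sketch of that pattern, with streams modeled as plain ints and all names hypothetical (this is not the nvFuser runtime API):

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical mock of the runtime calls emitted by the lowering pass.
// "set" and "sync" just record the order of operations so the
// prologue/epilogue pattern of the loop body is visible.
struct StreamLog {
  int current = 0;  // stream 0 plays the role of the user stream
  std::vector<std::string> events;

  void setCurrent(int s) {
    current = s;
    events.push_back("set " + std::to_string(s));
  }
  void synchronize(int s) { events.push_back("sync " + std::to_string(s)); }
};

// One iteration of the stream for-loop, following the dumped Host IR above:
// save the user stream, switch to stream (i % numberOfStreams), wait for the
// user stream, run the body, then restore and let the user stream wait back.
template <typename Body>
void streamIteration(StreamLog& log, int i, int numberOfStreams, Body body) {
  int userStream = log.current;  // GetCurrentStream
  int workStream = i % numberOfStreams;
  log.setCurrent(workStream);    // SetCurrentStream
  log.synchronize(userStream);   // work stream waits on the user stream
  body();                        // the selected Set/compute ops
  log.setCurrent(userStream);    // restore the user stream
  log.synchronize(workStream);   // user stream waits on the work stream
}
```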

Merging for loops

To avoid unnecessary synchronization across streams, it is important to be able to fuse the stream for-loops. This is exercised by the test MultiDeviceExecutorLowerStreamTest.TwoSetOps:

  TensorView* tv0 = makeContigTensor(2);
  TensorView* tv1 = set(tv0);
  TensorView* tv2 = set(tv1);
  fusion->addInput(tv0);
  fusion->addOutput(tv2);
  tv1->axis(0)->parallelize(ParallelType::Stream);
  tv2->axis(0)->parallelize(ParallelType::Stream);

Dump:

%HostIrContainer { (T0_g_float[iS0{i0}, iS1{i2}]) -> (T2_g_float[iStreamIdx4{i0}, iS5{i2}]) :
  T1_g_float[iStreamIdx2{i0}, iS3{i2}] = ALLOCATE(buffer=T1_g_float[iStreamIdx2{i0}, iS3{i2}], mem_type=global, size=( i0 * i2 ), zero_init=false, resets_to_zero=false)
  T2_g_float[iStreamIdx4{i0}, iS5{i2}] = ALLOCATE(buffer=T2_g_float[iStreamIdx4{i0}, iS5{i2}], mem_type=global, size=( i0 * i2 ), zero_init=false, resets_to_zero=false)
  FOR StreamIdx in iStreamIdx2{i0}:
    GetCurrentStream into Stream 0
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
    T3_l_float[iS6{i2}]
       = HirAliasSelect( T0_g_float[iS0{i0}, iS1{i2}], axis = iS0{i0}, index = StreamIdx )
    T4_l_float[iS7{i2}]
       = HirAliasSelect( T1_g_float[iStreamIdx2{i0}, iS3{i2}], axis = iStreamIdx2{i0}, index = StreamIdx )
    T4_l_float[iS7{i2}]
       = Set( T3_l_float[iS6{i2}], cache_op=Streaming )
    T5_l_float[iS8{i2}]
       = HirAliasSelect( T2_g_float[iStreamIdx4{i0}, iS5{i2}], axis = iStreamIdx4{i0}, index = StreamIdx )
    T5_l_float[iS8{i2}]
       = Set( T4_l_float[iS7{i2}], cache_op=Streaming )
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer

We observe that the for-loops are indeed merged.
A possible future optimization: the intermediate buffer could be allocated with length numberOfStreams only.
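That optimization rests on the fact that iterations i and i + numberOfStreams land on the same stream and therefore run serially, so the intermediate slice of iteration i could safely reuse slot i % numberOfStreams. A tiny sketch of the index mapping (helper names are hypothetical, not code from the PR):

```cpp
#include <cassert>

// Hypothetical slot mapping for a reduced intermediate buffer: iterations
// that map to the same stream execute serially on it, so their intermediate
// slices can share one slot without a race.
inline int intermediateSlot(int streamIdx, int numberOfStreams) {
  return streamIdx % numberOfStreams;
}

// The buffer then only needs min(loopExtent, numberOfStreams) slots
// instead of one slot per loop iteration.
inline int intermediateBufferLength(int loopExtent, int numberOfStreams) {
  return loopExtent < numberOfStreams ? loopExtent : numberOfStreams;
}
```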

Separating for loops

We also need to be able to separate and create new for-loops when necessary, as exercised in ThreeSetOpsWithDisjointsForLoops, which considers the fusion:

  TensorView* tv0 = makeContigTensor(2);
  TensorView* tv1 = set(tv0);
  TensorView* tv2 = set(tv1);
  TensorView* tv3 = set(tv2);
  fusion->addInput(tv0);
  fusion->addOutput(tv3);
  tv1->axis(0)->parallelize(ParallelType::Stream);
  tv3->axis(0)->parallelize(ParallelType::Stream);

Here, tv2 is not stream-parallelized, so it should not be produced inside a for-loop. Dump:

%HostIrContainer { (T0_g_float[iS0{i0}, iS1{i2}]) -> (T3_g_float[iStreamIdx6{i0}, iS7{i2}]) :
  T1_g_float[iStreamIdx2{i0}, iS3{i2}] = ALLOCATE(buffer=T1_g_float[iStreamIdx2{i0}, iS3{i2}], mem_type=global, size=( i0 * i2 ), zero_init=false, resets_to_zero=false)
  FOR StreamIdx in iStreamIdx2{i0}:
    GetCurrentStream into Stream 0
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
    T4_l_float[iS8{i2}]
       = HirAliasSelect( T0_g_float[iS0{i0}, iS1{i2}], axis = iS0{i0}, index = StreamIdx )
    T5_l_float[iS9{i2}]
       = HirAliasSelect( T1_g_float[iStreamIdx2{i0}, iS3{i2}], axis = iStreamIdx2{i0}, index = StreamIdx )
    T5_l_float[iS9{i2}]
       = Set( T4_l_float[iS8{i2}], cache_op=Streaming )
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
  T2_g_float[iS4{i0}, iS5{i2}]
     = Set( T1_g_float[iStreamIdx2{i0}, iS3{i2}], cache_op=Streaming )
  T3_g_float[iStreamIdx6{i0}, iS7{i2}] = ALLOCATE(buffer=T3_g_float[iStreamIdx6{i0}, iS7{i2}], mem_type=global, size=( i0 * i2 ), zero_init=false, resets_to_zero=false)
  FOR StreamIdx in iStreamIdx6{i0}:
    GetCurrentStream into Stream 2
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 2
    T6_l_float[iS10{i2}]
       = HirAliasSelect( T2_g_float[iS4{i0}, iS5{i2}], axis = iS4{i0}, index = StreamIdx )
    T7_l_float[iS11{i2}]
       = HirAliasSelect( T3_g_float[iStreamIdx6{i0}, iS7{i2}], axis = iStreamIdx6{i0}, index = StreamIdx )
    T7_l_float[iS11{i2}]
       = Set( T6_l_float[iS10{i2}], cache_op=Streaming )
    SetCurrentStream to Stream 2
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer
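The merge-or-split behavior in the examples above can be pictured as a simple scan over the top-level expressions: consecutive stream-parallel expressions share one for-loop, and any non-stream expression closes the current loop so later stream expressions open a fresh one. A minimal standalone sketch of that grouping (the `Expr`/`assignLoops` names are hypothetical, not the actual pass):

```cpp
#include <cassert>
#include <vector>

// Hypothetical model of a top-level Host IR expression: we only track
// whether its consumer has a Stream-parallelized axis.
struct Expr {
  bool streamParallel;
};

// Group consecutive top-level expressions into loops: runs of stream-parallel
// exprs are fused into one for-loop (merging), while a non-stream expr stays
// at the top level and splits the surrounding runs into disjoint loops.
// Returns, per expr, the loop id it lands in (-1 for exprs outside any loop).
std::vector<int> assignLoops(const std::vector<Expr>& exprs) {
  std::vector<int> loopIds;
  int nextLoop = 0;
  bool inLoop = false;
  for (const Expr& e : exprs) {
    if (e.streamParallel) {
      inLoop = true;                 // open (or stay in) the current for-loop
      loopIds.push_back(nextLoop);
    } else {
      if (inLoop) {                  // close the current loop
        inLoop = false;
        ++nextLoop;
      }
      loopIds.push_back(-1);
    }
  }
  return loopIds;
}
```

On TwoSetOps both exprs land in loop 0 (merged); on the ThreeSetOpsWithDisjointsForLoops pattern the middle non-stream expr separates loops 0 and 1, matching the two disjoint FOR blocks in the dump.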

@samnordmann samnordmann changed the base branch from main to host_irs/LoadStore_Reduction_binaryOp_support March 26, 2025 13:35
@github-actions

github-actions bot commented Mar 26, 2025

Review updated until commit 35ff4da

Description

  • Implement StreamParallelType pass for Host IR

  • Add tests for Host IR stream lowering

  • Fix swizzled_tiles calculation in fillDefaultHopperHeuristic


Changes walkthrough 📝

Relevant files
Refactoring
1 file
normalization_inner_outer.cpp
Refactor and modularize inner-outer reduction scheduling 
+21/-1604
Enhancement
11 files
fusion_segmenter.cpp
Add edge management and auxiliary group handling                 
+226/-246
kernel_ir.cpp
Add Continue expression and update utility names                 
+23/-2   
translation.cpp
Add FusionTranslator for translating CPP Fusion to FusionDefinition
[link]   
stream_parallel_type.cpp
Implement StreamParallelType pass for Host IR                       
+440/-0 
allocation.cpp
Update AllocationInserter to handle MmaOp with Blackwell 
+7/-5     
utils.cpp
Add IndexPutAccumulateOp to isTvOp check                                 
+1/-15   
fusion_cache.cpp
Add FusionCache implementation for serialization and deserialization
[link]   
normalization_inner_outer_multi_wave.cpp
Implement inner-outer multi-wave reduction scheduling       
+717/-0 
normalization_inner_outer_tma_ws.cpp
Implement inner-outer TMA warp specialized reduction scheduling
+647/-0 
vectorize_helper.cpp
Add logic to handle vectorization through Resize expressions
+85/-64 
utils.cpp
Update sharding logic to handle reduction IterDomains       
+15/-15 
Configuration changes
2 files
gen_nvfuser_version.py
Link to external version generation script                             
+0/-75   
memory.py
Link memory.py to the correct path                                             
+0/-28   
Tests
1 file
test_host_ir_stream_lowering.cpp
Add tests for Host IR stream lowering                                       
+814/-0 
Bug fix
1 file
matmul_utils.cpp
Fix swizzled_tiles calculation in fillDefaultHopperHeuristic
+1/-1     
Additional files
101 files
build.yml +1/-0     
lint.yml +6/-0     
.lintrunner.toml +2/-2     
CMakeLists.txt +29/-20 
README.md +10/-0   
core.py +12/-85 
test_cross_entropy_loss.py +2/-2     
test_matmul.py +6/-0     
alias_analysis.cpp +0/-4     
codegen.cpp +6/-2     
predicate_elimination.cpp +1/-1     
lower2device.h +6/-6     
circular_buffer.cpp +326/-152
index.cpp +14/-1   
index.h +1/-0     
inline_ptx.cpp +1/-1     
insert_syncs.cpp +59/-37 
utils.h +0/-10   
dispatch.h +4/-1     
fusion_segmenter.h +20/-15 
container.cpp +6/-4     
executor.cpp +45/-8   
executor.h +5/-0     
host_ir.cpp +45/-0   
host_ir.h +43/-0   
lower.cpp +10/-0   
stream_parallel_type.h +36/-0   
interface_nodes.h +2/-2     
internal_nodes.h +60/-0   
iostream.cpp +2/-7     
nodes.cpp +47/-0   
kernel_ir.h +18/-1   
logical_domain_map.cpp +33/-16 
c10d_mock.h +29/-1   
executor.h +4/-0     
ipc_handle.cpp +6/-4     
utils.h +9/-2     
indexing.cpp +38/-0   
indexing.h +6/-0     
options.cpp +10/-17 
options.h +22/-2   
parallel_dimension_map.h +2/-2     
predicate_compute.cpp +13/-13 
insert_reshardings.cpp +55/-4   
make_resharding_contiguous.cpp +122/-17
make_resharding_contiguous.h +11/-4   
optimization_pass.h +0/-2     
pre_segmenter.cpp +10/-6   
propagate_shardings.cpp +255/-76
compiled_kernel.cpp +7/-0     
executor_utils.h +2/-0     
fusion_kernel_runtime.cpp +1/-52   
expr_eval_sched.cpp +7/-1     
hopper_multi_matmul.cpp +9/-44   
hopper_multi_matmul.h +0/-8     
multi_matmul.cpp +54/-0   
multi_matmul.h +14/-0   
normalization_inner_outer_multi_wave.h +31/-0   
normalization_inner_outer_tma_ws.h +31/-0   
normalization_inner_outer_utils.cpp +301/-0 
normalization_inner_outer_utils.h +98/-0   
normalization_utils.cpp +23/-5   
reduction.cpp +13/-0   
reduction_utils.cpp +42/-3   
reduction_utils.h +3/-1     
registry.cpp +6/-1     
domain_map.cpp +12/-9   
domain_map.h +1/-1     
fusion_cache.fbs +1/-0     
fusion_record.cpp +7/-0     
type.cpp +3/-3     
type.h +3/-3     
nvfuser +1/-0     
LICENSE +1/-0     
README.md [link]   
__init__.py [link]   
__init__.pyi [link]   
benchmark_utils.py +105/-0 
__init__.py [link]   
__init__.py [link]   
normalization.py [link]   
nvfuser_version.py [link]   
pytorch_utils.py [link]   
__init__.py [link]   
utils.py [link]   
utils.py [link]   
pyproject.toml +3/-0     
distributed_tensor.cpp [link]   
distributed_tensor.h [link]   
fusion_cache.h [link]   
fusion_definition.cpp +5/-0     
fusion_definition.h [link]   
fusion_record.h +24/-0   
fusion_state.cpp [link]   
fusion_state.h [link]   
multidevice_bindings.cpp [link]   
python_bindings.cpp +42/-0   
python_bindings.h [link]   
python_bindings_extension.cpp [link]   
schedule_bindings.cpp [link]   
Additional files not shown

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 No relevant tests
⚡ Recommended focus areas for review

Code Duplication

The code for innerOuterWarpSpecializedTmaHeuristic and scheduleTmaWarpSpecializedInnerOuter seems to be very similar to the code for innerOuterPersistentHeuristic and scheduleInnerOuterPersistentKernel. Consider refactoring to reduce duplication.

    hp.threads_per_block_min,
    hp.threads_per_block_max);

auto rparams = std::make_unique<ReductionParams>(
    InnerOuterPersistentKernelScheduler::schedulerType());
// Ultimately, we want the heuristic to decide between using the
// warp-specialized version or the multi-wave version. The enable option is a
// temporary configuration to facilitate testing during development without
// disrupting existing behavior.
if (isOptionEnabled(EnableOption::WarpSpecializedNormalization)) {
  inner_outer_tma_warp_specialized::getHeuristics(
      rparams.get(),
      properties.total_iteration_numel,
      properties.total_reduction_numel,
      buffer_params.regs_buffer_size,

Removed Functionality
Several functions such as scheduleReductionCombinedOuter, scheduleInnerOuterPersistentKernel, scheduleTmaWarpSpecializedOuter, and scheduleTmaWarpSpecializedInnerOuter have been removed. Ensure that the new code in inner_outer_multi_wave.cpp and inner_outer_tma_warp_specialized.cpp covers all the functionality of the removed functions.

New File

The new file python/python_frontend/schedule_bindings.cpp introduces a significant amount of new code. Ensure that all new bindings are correctly implemented and tested.

// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#include <instrumentation.h>
#include <ir/interface_nodes.h>
#include <multidevice/utils.h>
#include <python_frontend/fusion_cache.h>
#include <python_frontend/fusion_definition.h>
#include <python_frontend/python_bindings.h>
#include <scheduler/tools/inlining.h>
#include <transform_replay.h>

namespace nvfuser::python_frontend {

void bindSchedule(py::class_<FusionDefinition>& fusion_def) {
  //! The SchedOperators class is a nested class of FusionDefinition to allow
  //! the user to query the class for the list of schedule operators.
  //!
  //! Example:
  //!   help(FusionDefinition.SchedOperators)
  //!
  //! Additional operators are expected to be defined below as needed.
  py::class_<FusionDefinition::SchedOperators> nvf_sched(
      fusion_def, "SchedOperators");
  nvf_sched.def(py::init<FusionDefinition*>());
  nvf_sched.def(
      "to_string",
      [](FusionDefinition::SchedOperators& self, Tensor tensor) {
        // NOTE: For debugging purposes, print the state of TensorView
        NVF_CHECK(
            self.validUse(),
            "Attempting to use a SchedOperators Op prior to definition!");
        // Determine if tensor is a result from a reduction operation.
        FusionDefinition* fd = self.fusion_definition;
        TensorView* tv =
            fd->getFusionState(tensor.index)->template as<TensorView>();
        return tv->toString();
      },
      py::arg("tensor"));
  nvf_sched.def(
      "user_schedule_ir",
      [](FusionDefinition::SchedOperators& self) {
        return self.fusion_definition->userScheduleIr();
      },
      py::return_value_policy::reference);
  //! experimental API for multidevice support
  nvf_sched.def(
      "_set_device_mesh",
      [](FusionDefinition::SchedOperators& self,
         Tensor tensor,
         const DeviceMesh& mesh) {
        NVF_CHECK(
            self.validUse(),
            "Attempting to use a SchedOperators Op prior to definition!");
        FusionDefinition* fd = self.fusion_definition;
        auto tv = fd->getFusionState(tensor.index)->template as<TensorView>();
        tv->setDeviceMesh(mesh);
      },
      py::arg("tensor"),
      py::arg("mesh"));
  nvf_sched.def(
      "parallelize",
      [](FusionDefinition::SchedOperators& self,
         Tensor tensor,
         int axis,
         const ParallelType& parallel_type) {
        NVF_CHECK(
            self.validUse(),
            "Attempting to use a SchedOperators Op prior to definition!");
        FusionDefinition* fd = self.fusion_definition;
        auto tv = fd->getFusionState(tensor.index)->template as<TensorView>();
        tv->axis(axis)->parallelize(parallel_type);
      },
      py::arg("tensor"),
      py::arg("axis"),
      py::arg("parallel_type"));
  nvf_sched.def(
      "merge",
      [](FusionDefinition::SchedOperators& self, Tensor arg, int dim) {
        FUSER_PERF_SCOPE("SchedOperators.merge");
        NVF_CHECK(
            self.validUse(),
            "Attempting to use a SchedOperators Op prior to definition!");
        FusionDefinition* fd = self.fusion_definition;
        auto input_tv =
            fd->getFusionState(arg.index)->template as<TensorView>();
        input_tv->merge(dim);
      },
      py::arg("arg"),
      py::arg("dim"));
  auto reduction_factor_func = [](FusionDefinition::SchedOperators& self,
                                  Tensor arg,
                                  const std::vector<int64_t>& dims) -> Tensor {
    FUSER_PERF_SCOPE("SchedOperators.reduction_factor");
    NVF_CHECK(
        self.validUse(),
        "Attempting to use a SchedOperators Op prior to definition!");
    FusionDefinition* fd = self.fusion_definition;
    TensorView* input_tv =
        fd->getFusionState(arg.index)->template as<TensorView>();
    TensorView* output_tv = input_tv->rFactor(dims);
    return fd->addTensor(output_tv);
  };
  nvf_sched.def(
      "reduction_factor",
      reduction_factor_func,
      py::arg("arg"),
      py::arg("dims"));
  nvf_sched.def(
      "rfactor", reduction_factor_func, py::arg("arg"), py::arg("dims"));
  nvf_sched.def(
      "reorder",
      [](FusionDefinition::SchedOperators& self,
         Tensor arg,
         const std::unordered_map<int64_t, int64_t>& old2new) {
        FUSER_PERF_SCOPE("SchedOperators.reorder");
        NVF_CHECK(
            self.validUse(),
            "Attempting to use a SchedOperators Op prior to definition!");
        FusionDefinition* fd = self.fusion_definition;
        auto input_tv =
            fd->getFusionState(arg.index)->template as<TensorView>();
        input_tv->reorder(old2new);
      },
      py::arg("arg"),
      py::arg("old2new"));
  nvf_sched.def(
      "split",
      [](FusionDefinition::SchedOperators& self,
         Tensor arg,
         int64_t dim,
         int64_t factor,
         bool inner_split) {
        FUSER_PERF_SCOPE("SchedOperators.split");
        NVF_CHECK(
            self.validUse(),
            "Attempting to use a SchedOperators Op prior to definition!");
        FusionDefinition* fd = self.fusion_definition;
        auto input_tv =
            fd->getFusionState(arg.index)->template as<TensorView>();
        input_tv->split(dim, factor, inner_split);
      },
      py::arg("arg"),
      py::arg("dim"),
      py::arg("factor"),
      py::arg("inner_split") = true);
  nvf_sched.def(
      "set_allocation_as_loop",
      [](FusionDefinition::SchedOperators& self, Tensor arg) {
        FUSER_PERF_SCOPE("SchedOperators.set_allocation_as_loop");
        NVF_CHECK(
            self.validUse(),
            "Attempting to use a SchedOperators Op prior to definition!");
        FusionDefinition* fd = self.fusion_definition;
        auto* tv = fd->getFusionState(arg.index)->template as<TensorView>();
        tv->setAllocationDomain(tv->getLoopDomain(), true);
      },
      py::arg("arg"));
  nvf_sched.def(
      "cache_after",
      [](FusionDefinition::SchedOperators& self,
         Tensor tensor,
         const LoadStoreOpType& op_type,
         const CacheOp& cache_op) -> Tensor {
        NVF_CHECK(
            self.validUse(),
            "Attempting to use a SchedOperators Op prior to definition!");
        FusionDefinition* fd = self.fusion_definition;
        TensorView* input_tv =
            fd->getFusionState(tensor.index)->template as<TensorView>();
        TensorView* output_tv = input_tv->cacheAfter(op_type, cache_op);
        return fd->addTensor(output_tv);
      },
      py::arg("tensor"),
      py::arg("op_type") = LoadStoreOpType::Set,
      py::arg("cache_op") = CacheOp::Unspecified);
  nvf_sched.def(
      "cache_before",
      [](FusionDefinition::SchedOperators& self,
         Tensor tensor,
         const LoadStoreOpType& op_type) -> Tensor {
        NVF_CHECK(
            self.validUse(),
            "Attempting to use a SchedOperators Op prior to definition!");
        FusionDefinition* fd = self.fusion_definition;
        TensorView* input_tv =
            fd->getFusionState(tensor.index)->template as<TensorView>();
        TensorView* output_tv = input_tv->cacheBefore(op_type);
        return fd->addTensor(output_tv);
      },
      py::arg("tensor"),
      py::arg("op_type") = LoadStoreOpType::Set);
  nvf_sched.def(
      "cache_fork",
      [](FusionDefinition::SchedOperators& self, Tensor tensor) -> Tensor {
        NVF_CHECK(
            self.validUse(),
            "Attempting to use a SchedOperators Op prior to definition!");
        FusionDefinition* fd = self.fusion_definition;
        TensorView* input_tv =
            fd->getFusionState(tensor.index)->template as<TensorView>();
        TensorView* output_tv = input_tv->cacheFork();
        return fd->addTensor(output_tv);
      },
      py::arg("tensor"));
  nvf_sched.def(
      "set_memory_type",
      [](FusionDefinition::SchedOperators& self,
         Tensor tensor,
         const MemoryType& memory_type) {
        NVF_CHECK(
            self.validUse(),
            "Attempting to use a SchedOperators Op prior to definition!");
        FusionDefinition* fd = self.fusion_definition;
        TensorView* tv =
            fd->getFusionState(tensor.index)->template as<TensorView>();
        tv->setMemoryType(memory_type);
      },
      py::arg("tensor"),
      py::arg("memory_type"));
  nvf_sched.def(
      "transform_like",
      [](FusionDefinition::SchedOperators& self,
         Tensor tensor,
         const std::vector<Tensor>& selected_tensors) {
        NVF_CHECK(
            self.validUse(),
            "Attempting to use a SchedOperators Op prior to definition!");

        FusionDefinition* fd = self.fusion_definition;
        TensorView* reference_tv =
            fd->getFusionState(tensor.index)->template as<TensorView>();

        TransformPropagator propagator(reference_tv);
        if (selected_tensors.empty()) {
          // Propagate scheduler transformations on reference TensorView to the
          // rest of the fusion.
          MaxLogicalDomainInfoSpanningTree(reference_tv).traverse(&propagator);
        } else {
          // Propagate scheduler transformations on reference TensorView to the
          // subset of the fusion.
          std::unordered_set<TensorView*> selected_tv_set;
          selected_tv_set.reserve(selected_tensors.size());
          std::transform(
              selected_tensors.begin(),
              selected_tensors.end(),
              std::inserter(selected_tv_set, selected_tv_set.end()),
              [&fd](const Tensor& t) {
                return fd->getFusionState(t.index)->template as<TensorView>();
              });
          SetSelector selector(
              {selected_tv_set.begin(), selected_tv_set.end()});
          MaxLogicalDomainInfoSpanningTree(reference_tv, &selector)
              .traverse(&propagator);
        }
      },
      py::arg("tensor"),
      py::arg("selected_tensors") = std::vector<Tensor>());
  nvf_sched.def(
      "parallelize_like",
      [](FusionDefinition::SchedOperators& self,
         Tensor tensor,
         int64_t pos,
         const std::vector<Tensor>& selected_tensors,
         const std::unordered_set<ParallelType>& selected_parallel_types,
         bool propagate_padding) {
        // Propagate the parallelization from the selected dimensions of the
        // reference tensor to their corresponding dimensions in all selected
        // tensors in the DAG.
        //
        // 1. Position `pos` means selecting all the dimensions
        // [0, 1, ..., pos - 1]. pos = -1 means selecting all dimensions.
        // 2. `selected_tvs` are selected tensors in the DAG. Empty
        // `selected_tvs` means selecting all tensors in the fusion of
        // `reference_tv`.
        // 3. `selected_parallel_types` are the selected parallel types. Empty
        // `selected_parallel_types` means selecting all parallel types.

        NVF_CHECK(
            self.validUse(),
            "Attempting to use a SchedOperators Op prior to definition!");

        FusionDefinition* fd = self.fusion_definition;
        TensorView* reference_tv =
            fd->getFusionState(tensor.index)->template as<TensorView>();

        std::vector<TensorView*> selected_tvs;
        selected_tvs.reserve(selected_tensors.size());
        std::transform(
            selected_tensors.begin(),
            selected_tensors.end(),
            std::back_inserter(selected_tvs),
            [&fd](const Tensor& t) {
              return fd->getFusionState(t.index)->template as<TensorView>();
            });

        nvfuser::scheduler_utils::parallelizeAllLike(
            reference_tv,
            pos,
            selected_tvs,
            selected_parallel_types,
            propagate_padding);
      },
      py::arg("tensor"),
      py::arg("pos") = -1,
      py::arg("selected_tensors") = std::vector<Tensor>(),
      py::arg("selected_parallel_types") = std::unordered_set<ParallelType>(),
      py::arg("propagate_padding") = true);
  nvf_sched.def(
      "inline_most",
      [](FusionDefinition::SchedOperators& self,
         const std::vector<Tensor>& selected_tensors) {
        // Inline to the right most allowed position for the selected tensors in
        // the current fusion.

        NVF_CHECK(
            self.validUse(),
            "Attempting to use a SchedOperators Op prior to definition!");

        FusionDefinition* fd = self.fusion_definition;

        if (selected_tensors.empty()) {
          nvfuser::inlineMost();
        } else {
          std::vector<TensorView*> selected_tvs;
          selected_tvs.reserve(selected_tensors.size());
          std::transform(
              selected_tensors.begin(),
              selected_tensors.end(),
              std::back_inserter(selected_tvs),
              [&fd](const Tensor& t) {
                return fd->getFusionState(t.index)->template as<TensorView>();
              });
          nvfuser::inlineMost(selected_tvs);
        }
      },
      py::arg("selected_tensors") = std::vector<Tensor>());
  nvf_sched.def(
      "inline_at",
      [](FusionDefinition::SchedOperators& self,
         Tensor tensor,
         int64_t pos,
         bool best_effort,
         const std::vector<Tensor>& selected_tensors) {
        NVF_CHECK(
            self.validUse(),
            "Attempting to use a SchedOperators Op prior to definition!");

        FusionDefinition* fd = self.fusion_definition;
        TensorView* reference_tv =
            fd->getFusionState(tensor.index)->template as<TensorView>();

        if (selected_tensors.empty()) {
          // Inline to the position corresponding to the reference position in
          // the reference tensor for all tensors in the current fusion.
          nvfuser::inlineAllAt(reference_tv, pos, best_effort);
        } else {
          // Inline to the position corresponding to the reference position in
          // the reference tensor for selected tensors in the current fusion.
          std::unordered_set<TensorView*> selected_tvs;
          selected_tvs.reserve(selected_tensors.size());
          std::transform(
              selected_tensors.begin(),
              selected_tensors.end(),
              std::inserter(selected_tvs, selected_tvs.end()),
              [&fd](const Tensor& t) {
                return fd->getFusionState(t.index)->template as<TensorView>();
              });

          nvfuser::inlineSelectedAt(
              selected_tvs, reference_tv, pos, best_effort);
        }
      },
      py::arg("tensor"),
      py::arg("pos") = -1,
      py::arg("best_effort") = false,
      py::arg("selected_tensors") = std::vector<Tensor>());
  nvf_sched.def("tensors", [](FusionDefinition::SchedOperators& self) {
    NVF_CHECK(
        self.validUse(),
        "Attempting to use a SchedOperators Op prior to definition!");
    // Return all Tensors in FusionDefinition
    return self.fusion_definition->tensors();
  });
  nvf_sched.def(
      "is_reduction",
      [](FusionDefinition::SchedOperators& self, Tensor tensor) {
        NVF_CHECK(
            self.validUse(),
            "Attempting to use a SchedOperators Op prior to definition!");
        // Determine if tensor is a result from a reduction operation.
        FusionDefinition* fd = self.fusion_definition;
        TensorView* tv =
            fd->getFusionState(tensor.index)->template as<TensorView>();
        return (
            !tv->isFusionInput() &&
            std::any_of(
                tv->getMaybeRootDomain().begin(),
                tv->getMaybeRootDomain().end(),
                [](IterDomain* id) { return id->isReduction(); }) &&
            !isResharding(tv->definition()));
      },
      py::arg("tensor"));
  nvf_sched.def(
      "can_schedule",
      [](FusionDefinition::SchedOperators& self,
         const SchedulerType& scheduler_type) {
        NVF_CHECK(
            self.validUse(),
            "Attempting to use a SchedOperators Op prior to definition!");
        return self.fusion_definition->userSchedule()->canScheduleDebug(
            scheduler_type);
      },
      py::arg("scheduler_type"));
  nvf_sched.def(
      "find_compatible_schedulers", [](FusionDefinition::SchedOperators& self) {
        NVF_CHECK(
            self.validUse(),
            "Attempting to use a SchedOperators Op prior to definition!");

        std::vector<SchedulerType> valid_scheduler_types;
        valid_scheduler_types.reserve(all_heuristics_in_priority_order.size());
        std::copy_if(
            all_heuristics_in_priority_order.begin(),
            all_heuristics_in_priority_order.end(),
            std::back_inserter(valid_scheduler_types),
            [sched = self.fusion_definition->userSchedule()](
                SchedulerType scheduler_type) {
              return sched->canSchedule(scheduler_type);
            });
        return valid_scheduler_types;
      });
  nvf_sched.def(
      "schedule",
      [](FusionDefinition::SchedOperators& self,
         const SchedulerType& scheduler_type) {
        NVF_CHECK(
            self.validUse(),
            "Attempting to use a SchedOperators Op prior to definition!");
        UserSchedule* sched = self.fusion_definition->userSchedule();
        auto&& [can_schedule, error_msg] =
            sched->canScheduleDebug(scheduler_type);
        NVF_CHECK(can_schedule, error_msg);
        sched->scheduleWithType(scheduler_type);
      },
      py::arg("heuristic"));
  nvf_sched.def("schedule", [](FusionDefinition::SchedOperators& self) {
    NVF_CHECK(
        self.validUse(),
        "Attempting to use a SchedOperators Op prior to definition!");
    UserSchedule* sched = self.fusion_definition->userSchedule();
    sched->schedule();
  });
  nvf_sched.def(
      "compute_pointwise_heuristics",
      [](FusionDefinition::SchedOperators& self) -> PointwiseParams& {
        NVF_CHECK(
            self.validUse(),
            "Attempting to use a SchedOperators Op prior to definition!");
        UserSchedule* sched = self.fusion_definition->userSchedule();
        HeuristicParams* parameters =
            sched->computeHeuristics(SchedulerType::PointWise);
        return *parameters->as<PointwiseParams>();
      },
      py::return_value_policy::reference);
  nvf_sched.def(
      "compute_reduction_heuristics",
      [](FusionDefinition::SchedOperators& self) -> ReductionParams& {
        NVF_CHECK(
            self.validUse(),
            "Attempting to use a SchedOperators Op prior to definition!");
        UserSchedule* sched = self.fusion_definition->userSchedule();
        HeuristicParams* parameters =
            sched->computeHeuristics(SchedulerType::Reduction);
        return *parameters->as<ReductionParams>();
      },
      py::return_value_policy::reference);
  nvf_sched.def(
      "compute_matmul_heuristics",
      [](FusionDefinition::SchedOperators& self) -> MatmulParams& {
        NVF_CHECK(
            self.validUse(),
            "Attempting to use a SchedOperators Op prior to definition!");
        UserSchedule* sched = self.fusion_definition->userSchedule();
        HeuristicParams* parameters =
            sched->computeHeuristics(SchedulerType::Matmul);
        return *parameters->as<MatmulParams>();
      },
      py::return_value_policy::reference);
  nvf_sched.def(
      "schedule_hyperparameters",
      [](FusionDefinition::SchedOperators& self)
          -> scheduler_utils::SchedulerHyperParameters& {
        NVF_CHECK(
            self.validUse(),
            "Attempting to use a SchedOperators Op prior to definition!");
        UserSchedule* sched = self.fusion_definition->userSchedule();
        auto scheduler_hyperparameters_entry = HeuristicDataCacheEntry<
            HeuristicCompileTime::SchedulerHyperParameters>(
            sched->data_cache.get(), []() {
              return std::make_unique<
                  scheduler_utils::SchedulerHyperParameters>(
                  /*vectorize_factor=*/1,
                  /*unroll_factor=*/1,
                  /*threads_per_block_min=*/1,
                  /*threads_per_block_max=*/1);
            });
        return scheduler_hyperparameters_entry.get();
      },
      py::return_value_policy::reference);
}

} // namespace nvfuser::python_frontend

@samnordmann samnordmann force-pushed the host_irs/stream_lowering/single_device_fusions branch from 1f717f4 to 282687a Compare March 26, 2025 13:38

using HirLowerStreamTest = NVFuserTest;

TEST_F(HirLowerStreamTest, InputsAreNotStreamParallelized) {
@samnordmann samnordmann Mar 26, 2025


I suggest to start the review by reading the PR description and inspecting those tests, running them with the option NVFUSER_DUMP=host_ir.

There are two identical sets of tests: HirLowerStreamTest and MultiDeviceExecutorLowerStreamTest. Both sets exercise semantically identical situations, but the former is "lower level" in the sense that it manually builds a Hir container and applies the pass, while the latter goes through the main-level API for defining a fusion, which then gets automatically lowered and executed through MultiDeviceExecutor.

@samnordmann

!test

@samnordmann samnordmann requested review from nsarka and wujingyue March 26, 2025 15:19
TensorView* tv,
int64_t dim,
Val* index,
bool keep_reduction_axis) {
Collaborator Author

I don't understand the logic behind removing the reduction axis when selecting: since it is an aliasing op, it does not produce the reduction so should keep it. Anyway, this is problematic in our context where we need the reduction axis since the selected tensor will actually become the reduced tensor. So, I decided to add this option to control this behavior

Collaborator

IIUC, a reduction dimension always stops at the op that performs the reduction. So the original code looks right and I don't follow the motivation of this change. Do you have an example where we have to keep the reduction dimension in a downstream op?

cc @naoyam in case I missed anything

Collaborator Author

IIUC, a reduction dimension always stops at the op that performs the reduction. So the original code looks right and I don't follow the motivation of this change.

Yes, but indeed, a select op does not perform the reduction, which is why it makes sense to keep the reduction axis.

Do you have an example where we have to keep the reduction dimension in a downstream op?

Consider the case of MultiDeviceExecutorLowerStreamTest.Reduction, that I am reproducing here:

```
auto tv0 = makeContigTensor(3); // [i0, i1, i2]
auto tv1 = sum(tv0, {2}); // [Stream(i0), i1, r2]

tv1->axis(0)->parallelize(ParallelType::Stream);
```

The stream lowering pass implements indexing into tv0 and tv1 by creating SelectOps, and letting the reduction operate on the selected tensors. The dumped Host program looks like:

```
%HostIrContainer { (T0_g_float[iS0{i0}, iS1{i2}, iS2{i3}]) -> (T1_g_float[iStream3{i0}, iS4{i2}, rS5{i3}]) :
  T1_g_float[iStream3{i0}, iS4{i2}, rS5{i3}] = ALLOCATE(buffer=T1_g_float[iStream3{i0}, iS4{i2}, rS5{i3}], mem_type=global, size=( i0 * i2 ), zero_init=false, resets_to_zero=false)
  FOR i10 in iStream3{i0}:
    GetCurrentStream into Stream 0
    SetCurrentStream to Stream ( i10 % numberOfStreams )
    Synchronize Stream 0
    T2_l_float[iS6{i2}, iS7{i3}]
       = select( T0_g_float[iS0{i0}, iS1{i2}, iS2{i3}], axis = iS0{i0}, index = i10 )
    T3_l_float[iS8{i2}, rS9{i3}]
       = select( T1_g_float[iStream3{i0}, iS4{i2}, rS5{i3}], axis = iStream3{i0}, index = i10 )
    T3_l_float[iS8{i2}, rS9{i3}]
       = reduction( T2_l_float[iS6{i2}, iS7{i3}], op = add, initial value = float(0), allreduce = false )
    SetCurrentStream to Stream 0
    Synchronize Stream ( i10 % numberOfStreams )
} // %HostIrContainer
```

The sum now operates not on tv0 and tv1, but on alias slices of them. This is why we need to keep the reduction axis on tv1's SelectOp's consumer.

I understand I am using SelectOp here to index, while it was originally thought of as a tensor's semantic op given in the user's DAG. But Hir indexing has other needs than kernel indexing: I cannot directly index into C-style arrays, I need to use at::select ops (or other ATen ops), to then feed the indexed tensor into other ATen ops (such as at::matmul, since we do not use matmul generated kernel for now). This is why I need to extend the usage of SelectOp.

Collaborator Author

btw, if this raises too much concern, I can create another HIR op to avoid collisions with SelectOp

Collaborator

IIUC, a reduction dimension always stops at the op that performs the reduction.

That's right.

Collaborator

Yes, but indeed, a select op does not perform the reduction, which is why it makes sense to keep the reduction axis.

It isn't about this op; it's the producer tensor that has reduction IDs. They are there because the op that produces the producer is a reduction. Reduction IDs are not kept around after the reduction op itself.

Collaborator

I can create another HIR op to avoid collisions with SelectOp

SGTM to unblock your experiment. Thanks for the example! IIUC, both the second select and the reduction define T3, causing a mismatch in r. Arguably, select isn't designed for this use case because it belongs to fusion IR that's supposed to be SSA. So creating a new Host IR for this use case makes sense to me.

@samnordmann samnordmann Apr 23, 2025


I did as you suggested and created a dedicated HIR in #4301

It actually solves other problems along the way, and allowed me to

  1. get rid of aliases in hic: indeed, I do not set the indexed tensor's definition when indexing, thus avoiding a cyclic dependency. HIC's alias is therefore not necessary anymore and I'll remove it in a separate PR
  2. remove the bool keep_reduction_axis option I added in select, newValLike, etc.

So the code is cleaner.

@samnordmann samnordmann force-pushed the host_irs/LoadStore_Reduction_binaryOp_support branch from 10daa92 to 85b7b75 Compare April 11, 2025 13:30
samnordmann added a commit that referenced this pull request Apr 11, 2025
# Context

## Previous bug fix in comms/compute overlap

A recent PR #3913 fixed a bug in the comms/compute overlap algorithm; the fix consisted of adding a stream synchronization when entering the host for-loop. Namely, the currently generated host program reads
```
$ mpirun -x NVFUSER_DUMP=host_ir -np 2 python -m pytest tests/python/multidevice/test_overlap.py --only-mpi
```
```
%HostIrContainer { (T0_g___bfloat[iS0{8}, ideviceIdx.x1{2}, iS2{64}, iS3{1024}] (DeviceMesh{0 1}), T1_g___bfloat[iS4{1024}, iS5{1024}] (DeviceMesh{0 1}), T2_g___bfloat[iS6{1024}] (DeviceMesh{0 1})) -> (T3_g___bfloat[iStream7{8}, iS8{2}, iS9{64}, iS10{1024}, rS11{1024}] (DeviceMesh{0 1})) :
  GetCurrentStream into Stream 0
  T4_g___bfloat[iS12{8}, iS13{2}, iS14{64}, iS15{1024}] (DeviceMesh{0 1}) = ALLOCATE(buffer=T4_g___bfloat[iS12{8}, iS13{2}, iS14{64}, iS15{1024}] (DeviceMesh{0 1}), mem_type=global, size=1048576, zero_init=false, resets_to_zero=false)
  T3_g___bfloat[iStream7{8}, iS8{2}, iS9{64}, iS10{1024}, rS11{1024}] (DeviceMesh{0 1}) = ALLOCATE(buffer=T3_g___bfloat[iStream7{8}, iS8{2}, iS9{64}, iS10{1024}, rS11{1024}] (DeviceMesh{0 1}), mem_type=global, size=1048576, zero_init=false, resets_to_zero=false)
  FOR i83 in iS0{8}:
    SetCurrentStream to Stream ( i83 % numberOfStreams )
    Synchronize Stream 0
    T5_l___bfloat[ideviceIdx.x16{2}, iS17{64}, iS18{1024}] (DeviceMesh{0 1})
       = select( T0_g___bfloat[iS0{8}, ideviceIdx.x1{2}, iS2{64}, iS3{1024}] (DeviceMesh{0 1}), axis = iS0{8}, index = i83 )
    T6_l___bfloat[iS19{2}, iS20{64}, iS21{1024}] (DeviceMesh{0 1})
       = select( T4_g___bfloat[iS12{8}, iS13{2}, iS14{64}, iS15{1024}] (DeviceMesh{0 1}), axis = iS12{8}, index = i83 )
    Communication 35 (type=Allgather, team=(0 1), input=T5_l___bfloat[ideviceIdx.x16{2}, iS17{64}, iS18{1024}] (DeviceMesh{0 1}), output=T6_l___bfloat[iS19{2}, iS20{64}, iS21{1024}] (DeviceMesh{0 1}), backend=NCCL)
    Wait Communication 35
    T7_l___bfloat[iS22{2}, iS23{64}, iS24{1024}] (DeviceMesh{0 1})
       = select( T3_g___bfloat[iStream7{8}, iS8{2}, iS9{64}, iS10{1024}, rS11{1024}] (DeviceMesh{0 1}), axis = iStream7{8}, index = i83 )
    T7_l___bfloat[iS22{2}, iS23{64}, iS24{1024}] (DeviceMesh{0 1})
       = linear(T6_l___bfloat[iS19{2}, iS20{64}, iS21{1024}] (DeviceMesh{0 1}),
                T1_g___bfloat[iS4{1024}, iS5{1024}] (DeviceMesh{0 1})      ,
          T2_g___bfloat[iS6{1024}] (DeviceMesh{0 1})      )
    SetCurrentStream to Stream 0
    Synchronize Stream ( i83 % numberOfStreams )
} // %HostIrContainer
```
Note the line `Synchronize Stream 0` at the beginning of the for-loop


## Degradation of performance

However, even though this was not verified before merging, PR
#3913 degraded the performance of the
overlap algorithm. Basically, since PR #3913, we no longer observe
overlapping, even when using UCC/TL/NCCL:
<img width="1224" alt="Screenshot 2025-04-10 at 15 30 32"
src="https://github.com/user-attachments/assets/e6e4fbcd-f5de-47b6-a05d-403dbc06ca74"
/>


# Performance fix in the present PR

## What
The current PR fixes this performance degradation. The idea was
found incidentally and reveals a probably interesting finding.

What needs to be done is to execute all the stream synchronization in a
separate host for-loop, before the host for-loop responsible for the
comms/compute. The new generated program reads:
```
%HostIrContainer { (T0_g___bfloat[iS0{8}, ideviceIdx.x1{2}, iS2{64}, iS3{1024}] (DeviceMesh{0 1}), T1_g___bfloat[iS4{1024}, iS5{1024}] (DeviceMesh{0 1}), T2_g___bfloat[iS6{1024}] (DeviceMesh{0 1})) -> (T3_g___bfloat[iStream7{8}, iS8{2}, iS9{64}, iS10{1024}, rS11{1024}] (DeviceMesh{0 1})) :
  GetCurrentStream into Stream 0
  T4_g___bfloat[iS12{8}, iS13{2}, iS14{64}, iS15{1024}] (DeviceMesh{0 1}) = ALLOCATE(buffer=T4_g___bfloat[iS12{8}, iS13{2}, iS14{64}, iS15{1024}] (DeviceMesh{0 1}), mem_type=global, size=1048576, zero_init=false, resets_to_zero=false)
  T3_g___bfloat[iStream7{8}, iS8{2}, iS9{64}, iS10{1024}, rS11{1024}] (DeviceMesh{0 1}) = ALLOCATE(buffer=T3_g___bfloat[iStream7{8}, iS8{2}, iS9{64}, iS10{1024}, rS11{1024}] (DeviceMesh{0 1}), mem_type=global, size=1048576, zero_init=false, resets_to_zero=false)
  FOR i83 in iS0{8}:
    SetCurrentStream to Stream ( i83 % numberOfStreams )
    Synchronize Stream 0
  FOR i83 in iS0{8}:
    SetCurrentStream to Stream ( i83 % numberOfStreams )
    T5_l___bfloat[ideviceIdx.x16{2}, iS17{64}, iS18{1024}] (DeviceMesh{0 1})
       = select( T0_g___bfloat[iS0{8}, ideviceIdx.x1{2}, iS2{64}, iS3{1024}] (DeviceMesh{0 1}), axis = iS0{8}, index = i83 )
    T6_l___bfloat[iS19{2}, iS20{64}, iS21{1024}] (DeviceMesh{0 1})
       = select( T4_g___bfloat[iS12{8}, iS13{2}, iS14{64}, iS15{1024}] (DeviceMesh{0 1}), axis = iS12{8}, index = i83 )
    Communication 36 (type=Allgather, team=(0 1), input=T5_l___bfloat[ideviceIdx.x16{2}, iS17{64}, iS18{1024}] (DeviceMesh{0 1}), output=T6_l___bfloat[iS19{2}, iS20{64}, iS21{1024}] (DeviceMesh{0 1}), backend=NCCL)
    Wait Communication 36
    T7_l___bfloat[iS22{2}, iS23{64}, iS24{1024}] (DeviceMesh{0 1})
       = select( T3_g___bfloat[iStream7{8}, iS8{2}, iS9{64}, iS10{1024}, rS11{1024}] (DeviceMesh{0 1}), axis = iStream7{8}, index = i83 )
    T7_l___bfloat[iS22{2}, iS23{64}, iS24{1024}] (DeviceMesh{0 1})
       = linear(T6_l___bfloat[iS19{2}, iS20{64}, iS21{1024}] (DeviceMesh{0 1}),
                T1_g___bfloat[iS4{1024}, iS5{1024}] (DeviceMesh{0 1})      ,
          T2_g___bfloat[iS6{1024}] (DeviceMesh{0 1})      )
    SetCurrentStream to Stream 0
    Synchronize Stream ( i83 % numberOfStreams )
} // %HostIrContainer
```

## Performance fix
The obtained nsight profile shows that we achieve perfect overlap:
<img width="1471" alt="Screenshot 2025-04-10 at 15 34 22"
src="https://github.com/user-attachments/assets/fe1d6242-c518-42c5-93e9-b2d1e78945a4"
/>

## Further todo:

The current PR modifies the function
`lowerToCollectiveBasedPipelinedGemmComm`, which will be removed and
replaced in #4147. We need to port the current patch there.
@samnordmann samnordmann force-pushed the host_irs/stream_lowering/single_device_fusions branch from 02d494e to b6c54f2 Compare April 16, 2025 10:50
@samnordmann

!test

samnordmann added a commit that referenced this pull request Apr 16, 2025
This PR belongs to a series of stacked PRs:
1. **=> You are here: #4144**
2. #4145
3. #4146
4. #4147

# What

- Support for aliases in HostIrContainer. When a Tensor tv1 is marked as
being the alias of tv0, then, at runtime, tv0's concrete data/buffer
will be used for the op. It is a way to reuse buffers that have been
allocated elsewhere within the TensorView's SSA paradigm. Chained
aliasing (tv2-->tv1-->tv0) is supported.
- Fix preallocated outputs in HostIrEvaluator 

# Why

It is necessary for stream parallelization, where typically we allocate
the full output buffer but each stream writes to a slice of this buffer.

# How 

The aliasing is stored in the HostIrContainer through a map.

At the HostIrEvaluator level, instead of operating directly on the
ExprEvaluator to write/read concrete data, we first apply the alias
indirection
samnordmann added a commit that referenced this pull request Apr 16, 2025
This PR belongs to a series of stacked PRs:
1. #4144
2. **=> You are here:** #4145
3. #4146
4. #4147

# What

1. We replace the bool option `SegmentCandidateFinderOptions::
only_segment_resharding_exprs` by a pointer to a predicate function.
This allows the user of the segmenter to (optionally) provide a custom
function used to decide whether two given groups should be merged.
This achieves better separation of responsibility: with this option, the
segmenter is only responsible for applying the segmentation algorithm,
and does not embed the application-specific rule for merging groups.
The specific rule in our context is decided by the Hir
lowering. Imo this refactoring should ideally go further and make the
segmenter a more abstract class that would be used in both Host Ir and
FusionExecutorCache lowering, only changing the newly introduced
function pointer.
2. In HIR lowering, we clearly separate (in a distinct for-loop, but
this later will become a preseg pass) the pass that transforms
resharding exprs into a Communication

# Why 
That's a preliminary refactoring useful for more advanced Host Ir
lowering, notably ParallelType::Stream lowering.
@samnordmann samnordmann force-pushed the host_irs/stream_lowering/single_device_fusions branch from b8fa021 to 165bd1b Compare April 16, 2025 15:58
@samnordmann

samnordmann commented Apr 16, 2025

In the last commit, `move stream_parallel_type to host_ir/pass folder`, I did as you suggested in #4145 (comment) and moved the stream lowering pass to a different folder. However, I let it inherit from OptimizationPass for now since I need the "disable guard" already in the current PR, which is already big. This guard is only needed temporarily in the tests which mix multi-device and stream parallel types. Soon, the disable guards won't be needed anymore, and I'll remove the inheritance of the pass.

@samnordmann

I pushed another (hopefully) last commit to further cleanup stream_parallel_type.cpp

@samnordmann

!test

naoyam and others added 3 commits April 23, 2025 17:00
Fusion segmenter sets aside a certain sequence of unary ops starting
with fusion inputs, which we call forwarding. It effectively works as an
optimization by recomputing (cheap) unary ops instead of passing tensors
from one segment to another.

This PR extends the forwarding optimization to those starting with
factory methods. Here's a motivating example (Litgpt Llama 3 RoPE
backward):


![llama_bwd](https://github.com/user-attachments/assets/84f83b2e-d7c6-4fad-9dee-6cc17578285d)

The `T81` tensor is the output of a full op. The tensor is used inside both
yellow and gray segments. The op itself is in the yellow segment, so
it's created inside the yellow segment, and that is passed, through
gmem, to the gray segment. Obviously, cheap ops like this should be just
replicated in the gray segment instead of passing a full tensor. Here's
another way to see it:

```
g{(resize)
group id: 4
inputs:
  T1_g___bfloat[bS3{1}, iS4{32}, iS5{8192}, iS6{128}] __bfloat
  T3_g___bfloat[bS11{1}, iS12{8192}, iS13{128}] __bfloat
  T9_g___bfloat[bS38{1}, bS39{1 ex 32}, iS40{8192}, iS41{128}] __bfloat
outputs:
  T25_g___bfloat[bS107{1}, bS108{1 ex 32}, iS109{8192}, iS110{128}] __bfloat
  T54_g___bfloat[bS233{1}, iS238{8}rf, iS239{4}rf, iS235{8192}, iS236{128}] __bfloat
  T81_g___bfloat[bS366{1}, iS367{32}, iS368{8192}, iS369{128}] __bfloat


T81_g___bfloat[bS366{1}, iS367{32}, iS368{8192}, iS369{128}]
   = full({1, 32, 8192, 128}, __bfloat(0));
(121)
...
```

And `T81` is used in the next segment of:

```
g{(resize)
group id: 3
inputs:
  T18_g___bfloat[bS79{1}, iS80{32}, iS81{8192}, iS82{128}] __bfloat
  T25_g___bfloat[bS107{1}, bS108{1 ex 32}, iS109{8192}, iS110{128}] __bfloat
  T34_g___bfloat[bS144{1}, iS145{32}, iS146{8192}, iS147{128}] __bfloat
  T81_g___bfloat[bS366{1}, iS367{32}, iS368{8192}, iS369{128}] __bfloat
outputs:
  T75_g___bfloat[bS328{1}, iS329{8}, iS331{6}rf, iS332{8192}, iS333{128}] __bfloat


T50_l___bfloat[bS212{1}, iS213{32}, iS214{8192}, iS216{64}rf]
   = slice( T34_g___bfloat[bS144{1}, iS145{32}, iS146{8192}, iS147{128}], { {0, 1, 1} {0, 32, 1} {0, 8192, 1} {64, 128, 1} } )
(52)
T55_g___bfloat[bS240{1}, iS241{32}, iS242{8192}, iS244{128}rf]
   = pad( T50_l___bfloat[bS212{1}, iS213{32}, iS214{8192}, iS216{64}rf], {0, 0, 0, 0, 0, 0, 0, 64} )
(61)
T39_g___bfloat[bS166{1}, iS167{32}, iS168{8192}, iS170{64}rf]
   = slice( T34_g___bfloat[bS144{1}, iS145{32}, iS146{8192}, iS147{128}], { {0, 1, 1} {0, 32, 1} {0, 8192, 1} {0, 64, 1} } )
(39)
T43_l_float[bS184{1}, iS185{32}, iS186{8192}, iS187{64}]
   = __bfloat2float(T39_g___bfloat[bS166{1}, iS167{32}, iS168{8192}, iS170{64}rf]);
(44)
T46_l_float[bS196{1}, iS197{32}, iS198{8192}, iS199{64}]
   = -T43_l_float[bS184{1}, iS185{32}, iS186{8192}, iS187{64}];
(47)
T48_g___bfloat[bS204{1}, iS205{32}, iS206{8192}, iS207{64}]
   = __float2bfloat(T46_l_float[bS196{1}, iS197{32}, iS198{8192}, iS199{64}]);
(49)
T51_g___bfloat[bS217{1}, iS218{32}, iS219{8192}, iS221{128}rf]
   = pad( T48_g___bfloat[bS204{1}, iS205{32}, iS206{8192}, iS207{64}], {0, 0, 0, 0, 0, 0, 64, 0} )
(54)
T38_l_float[bS162{1}, iS163{32}, iS164{8192}, iS165{128}]
   = __bfloat2float(T81_g___bfloat[bS366{1}, iS367{32}, iS368{8192}, iS369{128}]);
(101)
...
```

There are multiple ways to achieve that. What seems to make the most sense
to me is to extend the existing forwarding method to handle cases like
this. The existing method only considers ops starting with fusion
inputs, which do not include factory-created tensors.

This PR applies a small change to the forwarding logic to include
factory ops as well. The end result of this change with the above
example case is that the full result is no longer passed around. Here's
the first segment:

```
g{(resize)
group id: 3
inputs:
  T0_g___bfloat[bS0{1}, iS1{8192}, iS2{128}] __bfloat
  T1_g___bfloat[bS3{1}, iS4{32}, iS5{8192}, iS6{128}] __bfloat
  T3_g___bfloat[bS11{1}, iS12{8192}, iS13{128}] __bfloat
outputs:
  T49_g___bfloat[bS208{1}, iS209{32}, iS210{8192}, iS211{128}] __bfloat


T20_l___bfloat[bS87{1}, bS88{1}, iS89{8192}, iS90{128}]
   = broadcast( T3_g___bfloat[bS11{1}, iS12{8192}, iS13{128}], flags = {false, true, false, false} )
(16)
T25_g___bfloat[bS107{1}, bS108{1 ex 32}, iS109{8192}, iS110{128}] = expand( T20_l___bfloat[bS87{1}, bS88{1}, iS89{8192}, iS90{128}], {1, 32, 8192, 128} )
(129)
T5_l___bfloat[bS18{1}, bS19{1}, iS20{8192}, iS21{128}]
   = broadcast( T0_g___bfloat[bS0{1}, iS1{8192}, iS2{128}], flags = {false, true, false, false} )
(0)
T9_g___bfloat[bS38{1}, bS39{1 ex 32}, iS40{8192}, iS41{128}] = expand( T5_l___bfloat[bS18{1}, bS19{1}, iS20{8192}, iS21{128}], {1, 32, 8192, 128} )
(128)
T81_g___bfloat[bS366{1}, iS367{32}, iS368{8192}, iS369{128}]
   = full({1, 32, 8192, 128}, __bfloat(0));
...
```

Notice that `T81` is no longer a segment output. And the second segment
is:

```
g{(resize)
group id: 4
inputs:
  T0_g___bfloat[bS0{1}, iS1{8192}, iS2{128}] __bfloat
  T2_g___bfloat[bS7{1}, iS8{32}, iS9{8192}, iS10{128}] __bfloat
  T3_g___bfloat[bS11{1}, iS12{8192}, iS13{128}] __bfloat
outputs:
  T74_g___bfloat[bS321{1}, iS326{8}rf, iS327{4}rf, iS323{8192}, iS324{128}] __bfloat


T20_l___bfloat[bS87{1}, bS88{1}, iS89{8192}, iS90{128}]
   = broadcast( T3_g___bfloat[bS11{1}, iS12{8192}, iS13{128}], flags = {false, true, false, false} )
(16)
T25_g___bfloat[bS107{1}, bS108{1 ex 32}, iS109{8192}, iS110{128}] = expand( T20_l___bfloat[bS87{1}, bS88{1}, iS89{8192}, iS90{128}], {1, 32, 8192, 128} )
(129)
T5_l___bfloat[bS18{1}, bS19{1}, iS20{8192}, iS21{128}]
   = broadcast( T0_g___bfloat[bS0{1}, iS1{8192}, iS2{128}], flags = {false, true, false, false} )
(0)
T9_g___bfloat[bS38{1}, bS39{1 ex 32}, iS40{8192}, iS41{128}] = expand( T5_l___bfloat[bS18{1}, bS19{1}, iS20{8192}, iS21{128}], {1, 32, 8192, 128} )
(128)
T81_g___bfloat[bS366{1}, iS367{32}, iS368{8192}, iS369{128}]
   = full({1, 32, 8192, 128}, __bfloat(0));
(121)
...
```
This PR extends the `propagateSharding` presegmentation pass for DID
loop splits.
Key changes:
1. We use TransformPropagator for all expressions except `ViewOp` which
is handled manually since TransformPropagator does not support it
without first propagating the reshape to the producer.
2. `makeReshardingContiguous` sets the allocation domain for tvs with a device
mesh. Ideally, we would set it only for global tensors, but which tensors are
global is not known before segmentation, while the allocation domain must be
set before segmentation.
3. ~~The following tests are modified: see
[discussion](#3838 (comment)).~~
PR #4274 resolved this.

Follow-up PRs:

- `ViewOp` will be handled in a followup PR.
- Currently, we only backpropagate sharding for a tv that does not
already have a device dimension. This can be extended to propagate for
all parallel types not present on the tv. This will be done in a
followup. Backpropagating shardings can incorrectly change DIDx to
serial or modify DIDx to be on another location. `shardAllLike` can be
modified to specify which parallel type to propagate. Since
`insertResharding` and `propagateSharding` require different behavior, I
will handle it in a separate PR.
- Use `TransformReplay::CasP` in lieu of TransformPropagator.
- Propagate DID transforms within `castOp`:
[privatizeUpcast](https://github.com/NVIDIA/Fuser/blob/ed687366cf717837c8ea3e40f56542fec48e1616/csrc/fusion_segmenter.cpp#L4235-L4238)
clones cast operations, which fails segmentation since the transforms
are not replicated.

Findings from experiments:
#3838 (comment)

---------

Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
The motivation is to use them in Thunder.
@wujingyue wujingyue left a comment


Looking great!

FYI, I skipped stream_parallel_type.cpp because I may end up doing loop inlining differently and don't want to slow down your experiments.

@samnordmann

!test

rdspring1 and others added 3 commits April 24, 2025 06:55
This PR updates the build to use a `pyproject.toml` and isolates the
python bindings into `python` directory.

## Install From Source:
```bash
git clone https://github.com/NVIDIA/Fuser.git
cd Fuser
pip install -r python/requirements.txt

[MAX_JOBS] python setup.py develop [args]  # DEPRECATED
pip install --no-build-isolation -e python -v
```    

## Details
- Moved `csrc/python_frontend` and `nvfuser` to `python`.
- Moved `tools/gen_nvfuser_version.py` and `tools/memory.py` to
`python`.
- Created a new `setup.py` in `python`. This is the new primary
`setup.py`.
- Updated github workflows
- Created symbolic links to support `setup.py` in root directory.

## Changes to argument passing to `root/setup.py` and
`root/python/setup.py`
- `python/utils.py` has the common utilities between `root/setup.py` and
`root/python/setup.py`
- Updated argument parsing to use `argparse` to create a `dataclass`
configuration.
- The `argparse` creates a default `dataclass` if no arguments are
provided on the command line.
- `NVFUSER_BUILD_ENV_VARS` then overrides the values in the `dataclass`.
- The `root/setup.py` only supports command-line arguments.

---------

Co-authored-by: Wang, Xiao <24860335+xwang233@users.noreply.github.com>
This PR fixes the clang-tidy lintrunner. The `build_dir` argument needs
to change from `./build` to `./python/build`.
#4242 turned on "grid traversal factor" which is a good thing. However,
it exposed a bug in how we limit that factor to prevent overrun in case
the swizzled axis has fewer tiles than the factor. This led to a
regression from 58% to 35% geomean perf compared to eager on H200.

This PR swaps the axes used to compute the number of swizzled tiles and
takes us from a geomean of 35% to 65% on
`benchmarks/python/test_matmul.py` on H200.
@samnordmann

I may end up doing loop inlining differently

For my own learning, I'd be interested to hear your ideas

wujingyue and others added 4 commits April 25, 2025 12:13
Co-authored-by: root <26priya11@gmail.com>
Co-authored-by: Priya Mishra <52657555+Priya2698@users.noreply.github.com>
This is a follow-up to #3906, which added a WAR to #3640. While it's
safe, it turned out it's just too conservative. For example, here's a
concat pattern appearing in the backward of Litgpt Llama RoPE:

```
Inputs:
  T0_g___bfloat[bS0{1}, iS1{8}, iS2{4}, iS3{8192}, iS4{128}]
  T1_g___bfloat[bS5{1}, iS6{8}, bS7{1}, iS8{8192}, iS9{128}]
  T2_g___bfloat[bS10{1}, iS11{8}, bS12{1}, iS13{8192}, iS14{128}]
Outputs:
  T8_g___bfloat[bS43{1}, iS44{8192}, iS52{6144}rf]

%kernel_math {
T3_l___bfloat[bS15{1}, iS16{8}, iS18{6}rf, iS19{8192}, iS20{128}]
   = pad( T0_g___bfloat[bS0{1}, iS1{8}, iS2{4}, iS3{8192}, iS4{128}], {0, 0, 0, 0, 0, 2, 0, 0, 0, 0} )
i31 = 0 + 4;
T4_l___bfloat[bS21{1}, iS22{8}, iS24{( ( ( 0 + 4 ) + 1 ) + 1 )}rf, iS25{8192}, iS26{128}]
   = pad( T1_g___bfloat[bS5{1}, iS6{8}, bS7{1}, iS8{8192}, iS9{128}], {0, 0, 0, 0, i31, 1, 0, 0, 0, 0} )
i47 = i31 + 1;
T5_l___bfloat[bS27{1}, iS28{8}, iS30{( ( ( 0 + 4 ) + 1 ) + 1 )}rf, iS31{8192}, iS32{128}]
   = pad( T2_g___bfloat[bS10{1}, iS11{8}, bS12{1}, iS13{8192}, iS14{128}], {0, 0, 0, 0, i47, 0, 0, 0, 0, 0} )
T6_l___bfloat[bS33{1}, iS34{8}, iS35{6}, iS36{8192}, iS37{128}]
   = cat( T3_l___bfloat[bS15{1}, iS16{8}, iS18{6}rf, iS19{8192}, iS20{128}], T4_l___bfloat[bS21{1}, iS22{8}, iS24{( ( ( 0 + 4 ) + 1 ) + 1 )}rf, iS25{8192}, iS26{128}], T5_l___bfloat[bS27{1}, iS28{8}, iS30{( ( ( 0 + 4 ) + 1 ) + 1 )}rf, iS31{8192}, iS32{128}], 2 )
T7_l___bfloat[bS38{1}, iS41{8192}, iS39{8}, iS40{6}, iS42{128}]
   = Set.Permute( T6_l___bfloat[bS33{1}, iS34{8}, iS35{6}, iS36{8192}, iS37{128}], cache_op=Streaming )
T8_g___bfloat[bS43{1}, iS44{8192}, iS52{6144}rf] = view( T7_l___bfloat[bS38{1}, iS41{8192}, iS39{8}, iS40{6}, iS42{128}] )
} // %kernel_math
```

This is currently taken by the pointwise scheduler, which attempts to
vectorize the innermost ID of the output (i.e., `iS52{6144}`). Since the
resize ops of the three pad ops are reachable from `iS52`, the WAR of
#3640 simply takes them into consideration by calculating gcd with the
left and right expand factors. In this case, since there's an expand
factor of 1, the resulting vectorization factor is also just 1, which is
clearly not what we want. Here, while the resized ID itself is not
vectorizable due to the expand factor of 1, all of the resized tensors
have large enough inner IDs that should allow the maximum vectorization.

To make the WAR a little less conservative, this PR also checks whether the
constraint imposed by a Resize expr may be missed by the vectorization analysis.
In the above case, that should not happen as there's only one path
through each of the resize-based tensor ops.

This change is still not able to eliminate false positives completely.
See one of the new tests that is currently disabled.

The codediff results all seem to make sense. http://nv/eFb. Previously
some of the tests did not have vectorization due to the WAR, which is
relaxed in this PR and allows some vectorization.
This PR belongs to a series of stacked PRs:
1. #4144
2. #4145
3. **=> You are here:** #4146
4. #4147

Add support for `LoadStoreOp`, `BinaryOp`, `ReductionOp`, including
support for pre-allocated output, which is not provided by
ExprEvaluator.

---------

Co-authored-by: Jingyue Wu <wujingyue@gmail.com>
An error occurred while trying to automatically change base from host_irs/LoadStore_Reduction_binaryOp_support to main April 27, 2025 11:16
# What
Add a `SelectOp`-like HIR to express indexing into ATen tensor.

# Why
It is used in the context of stream lowering; see
#4147 and
especially the discussion in
#4147 (comment)
samnordmann added a commit that referenced this pull request Apr 28, 2025
Porting PR #4147 to here, which is based on main.

This PR only serves as a reference since it got broken down into the
following PRs to ease reviewing and merging:
1. #4144
2. #4145
3. #4146
4. #4147
@samnordmann samnordmann merged commit d92da5e into host_irs/LoadStore_Reduction_binaryOp_support Apr 28, 2025
60 checks passed
@samnordmann samnordmann deleted the host_irs/stream_lowering/single_device_fusions branch April 28, 2025 09:02