Skip to content

[HostIr] Refactor hir optimization pass#4401

Merged
samnordmann merged 9 commits intomainfrom
new_hir_optim_pass_and_refactor
May 14, 2025
Merged

[HostIr] Refactor hir optimization pass#4401
samnordmann merged 9 commits intomainfrom
new_hir_optim_pass_and_refactor

Conversation

@samnordmann
Copy link
Collaborator

What

This PR is only about cleaning and refactoring, no change in the behavior.

  • We create a new base class for HIR lowering optimization pass hir_pass::OptimizationPass in host_ir/pass/optimization_pass.h.
  • We create an option hir_lowering_logging to control debug logging through NVFUSER_DUMP
  • We create a guard to enable/disable a pass
  • We make the existing pass StreamLowering and InsertDeallocations inherit from this pass
  • We factor out converting a resharding op to a communication into a separate pass

Why

preparation for #4387

@samnordmann
Copy link
Collaborator Author

!test

@github-actions
Copy link

github-actions bot commented May 9, 2025

Review updated until commit f76fc10

Description

  • Created a new base class OptimizationPass for HIR lowering optimization passes.

  • Introduced ConvertOpToCommunication pass to handle converting operations to communications.

  • Moved communication lowering logic to lower_to_communication.cpp.

  • Refactored HostIrLower to use convertSingleOpToCommunication.


Changes walkthrough 📝

Relevant files
Enhancement
19 files
executor.cpp
Updated to use `convertSingleOpToCommunication`.                 
+4/-3     
lower.cpp
Removed communication lowering logic, updated to use
convertSingleOpToCommunication.
+4/-556 
lower_to_communication.cpp
Added new file with communication lowering logic.               
+550/-0 
convert_op_to_communication.cpp
Added new pass for converting operations to communications.
+75/-0   
insert_deallocations.cpp
Updated to use new `InsertDeallocations` class.                   
+9/-5     
stream_parallel_type.cpp
Updated to use new `OptimizationPass` base class.               
+3/-3     
options.cpp
Added new option `host_ir_lowering_logging`.                         
+1/-0     
fusion_kernel_runtime.cpp
Updated to use new convertSingleOpToCommunication and
InsertDeallocations class.
+7/-4     
fusion_definition.cpp
Updated to use new `OptimizationPassGuard` with `StreamParallelType`.
+1/-1     
test_host_ir_stream_lowering.cpp
Updated to use new `StreamParallelType` class.                     
+15/-22 
test_multidevice_host_ir.cpp
Updated to use new OptimizationPassGuard with StreamParallelType and
ReorderShardedAxisPass.
+11/-5   
lower.h
Added declaration for `convertSingleOpToCommunication`.   
+9/-1     
lower_to_communication.h
Added new header for communication lowering logic.             
+20/-0   
convert_op_to_communication.h
Added new header for `ConvertOpToCommunication` pass.       
+36/-0   
insert_deallocations.h
Updated to use new `InsertDeallocations` class.                   
+12/-3   
optimization_pass.h
Added new base class `OptimizationPass`.                                 
+88/-0   
stream_parallel_type.h
Updated to use new `OptimizationPass` base class.               
+6/-7     
options.h
Added new option `host_ir_lowering_logging`.                         
+1/-0     
CMakeLists.txt
Added new source files for communication lowering and passes.
+2/-0     

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Code Duplication

The code for lowering expressions to communications was moved from HostIrLower::lower to convertSingleOpToCommunication in lower_to_communication.cpp. However, the original HostIrLower::lower method still contains a significant amount of duplicate code that should be removed to avoid redundancy.

bool HostIrLower::canLower(Expr* expr, bool ignore_inner_resharding) {
  if (!isResharding(expr)) {
    return true;
  }
  if (!ir_utils::isTvOp(expr)) {
    return false;
  }
  if (auto* reduction = dynamic_cast<ReductionOp*>(expr)) {
    if (!ignore_inner_resharding && isInnerResharding(expr)) {
      return false;
    }
    auto in = reduction->in()->as<TensorView>();
    auto out = reduction->out()->as<TensorView>();
    // get the reduced axis
    std::vector<IterDomain*> reduction_axis;
    std::copy_if(
        out->getLogicalDomain().begin(),
        out->getLogicalDomain().end(),
        std::back_inserter(reduction_axis),
        [](IterDomain* id) { return id->isReduction(); });
    // check whether the reduction involves only one axis
    if (reduction_axis.size() != 1) {
      return false;
    }
    // We check whether the reduced axis is sharded on the input
    const auto c2p_map =
        PairwiseLogicalDomainMap(in, out).mapConsumerToProducer();
    auto c2p_map_it = c2p_map.find(reduction_axis.at(0));
    return c2p_map_it != c2p_map.end() && c2p_map_it->second->isDeviceDim();
  } else if (auto* ldst = dynamic_cast<LoadStoreOp*>(expr)) {
    if (!ignore_inner_resharding && isInnerResharding(expr)) {
      return false;
    }
    return ldst->as<LoadStoreOp>()->opType() == LoadStoreOpType::Set;
  } else if (auto* matmul = dynamic_cast<MatmulOp*>(expr)) {
    // For now we only support out = matmul(a,b) when b, out are fully
    // replicated, a is sharded on axis 1, and out i stream-parallelized on axis
    // 0.
    return !isSharded(matmul->inB()) && !isSharded(matmul->out()) &&
        matmul->inA()->axis(0)->getParallelType() == ParallelType::Serial &&
        getShardedLogicalAxis(matmul->inA(), ParallelType::DIDx) == 1 &&
        matmul->out()->axis(0)->getParallelType() == ParallelType::Stream;
  } else if (auto* linear = dynamic_cast<LinearOp*>(expr)) {
    // For now we only support out = linear(a, b, bias) when b, bias, and out
    // are fully replicated, a is sharded on axis 1, and out i
    // stream-parallelized on axis 0.
    auto* a = linear->inA()->as<TensorView>();
    auto* b = linear->inB()->as<TensorView>();
    auto* bias =
        (linear->has_bias() ? linear->bias()->as<TensorView>() : nullptr);
    auto* out = linear->out()->as<TensorView>();
    return !isSharded(b) && !(linear->has_bias() && isSharded(bias)) &&
        !isSharded(out) &&
        a->axis(0)->getParallelType() == ParallelType::Serial &&
        getShardedLogicalAxis(a, ParallelType::DIDx) == 1 &&
        out->axis(0)->getParallelType() == ParallelType::Stream;
  }
  return false;
}

bool HostIrLower::isLowerableAsStandaloneHostOp(Expr* expr) {
  if (expr->isOneOf<
          MatmulOp,
          SliceOp,
          SelectOp,
          LinearOp,
          BinaryOp,
          ReductionOp,
          Communication,
          P2PCommunication>()) {
    return true;
  }

  // Lower as standalone op "set" ops, i.e., LoadStoreOp of "Set" type with no
  // permute
  if (expr->isA<LoadStoreOp>()) {
    auto* load_store = expr->as<LoadStoreOp>();
    if (load_store->opType() == LoadStoreOpType::Set &&
        load_store->out()->isA<TensorView>()) {
      auto* tv = load_store->out()->as<TensorView>();
      // If the output tensor has no root, it means it has no permute
      if (!tv->hasRoot()) {
        return true;
      }
    }
  }

  return false;
}

bool HostIrLower::shouldMergeSegmentedGroups(
    SegmentedGroup* group1,
    SegmentedGroup* group2) {
  for (auto group : {group1, group2}) {
    for (Expr* expr : group->exprs()) {
      if (isLowerableAsStandaloneHostOp(expr)) {
        return false;
      }
    }
  }
  return true;
}

std::unique_ptr<hir::HostIrContainer> HostIrLower::lower(
    std::unique_ptr<Fusion> fusion,
    DeviceIdxType my_device_index) {
  // Sharding PreSegmenter passes.
  // Note: passes run before PreSegmenter optimization passes.
  preseg_passes::OptimizationPass<
      preseg_passes::PropagateShardingsPass>::runPass(fusion.get());
  preseg_passes::OptimizationPass<
      preseg_passes::InsertReshardingsPass>::runPass(fusion.get());
  preseg_passes::OptimizationPass<
      preseg_passes::ReorderShardedAxisPass>::runPass(fusion.get());
  preseg_passes::OptimizationPass<
      preseg_passes::MakeReshardingContiguousPass>::runPass(fusion.get());

  // Performs segmentation at the inter-device communications
  // Each SegmentedGroup represents a pipeline's stage, and can be either
  // 1) a Fusion which doesn't involve inter-device communication
  // 2) a Fusion comprised of one Expr, representing inter-device communication
  SegmentCandidateFinderOptions options{
      .run_translate_welford = false,
      .run_combine_reductions = false,
      .run_herrmann_merge = true,
      .run_final_merge = true,
      .custom_should_merge_groups = &shouldMergeSegmentedGroups};
  std::unique_ptr<SegmentedFusion> staged_fusion =
      SegmentCandidateFinder::segment(
          std::move(fusion), KernelArgumentHolder(), options, true);
  // Infer a topologically ordered traversal of the segmented fusion to
  // determine the order for launching the kernels/comms
  RuntimeWorkSpace workspace;
  prepareRuntimeOrder(staged_fusion.get(), workspace);
  // Create the HostIrContainer representing the host program. Each segment of
  // the segmented fusion will be translated to a HostIR
  auto hic = std::make_unique<hir::HostIrContainer>();
  FusionGuard fg(hic.get());
  IrCloner ir_cloner(hic.get());
  auto clone =
      [&ir_cloner](const std::vector<Val*>& vals) -> std::vector<Val*> {
    std::vector<Val*> cloned_vals(vals.size());
    std::transform(
        vals.begin(), vals.end(), cloned_vals.begin(), [&ir_cloner](Val* val) {
          return ir_cloner.clone(val);
        });
    return cloned_vals;
  };

  for (auto group : workspace.group_run_order) {
    NVF_ERROR(!group->exprs().empty(), "invalid segmentation");
    if (involvedDevices(group->exprs().at(0)).count(my_device_index) == 0) {
      continue;
    }
    // we decide whether to insert the Expr as a standalone op in the
    // HostIRContainer, which will result in using ATen Op to evaluate it --
    // or, alternatively, to wrap them into a PostOnStream(HostUnit(.)) which
    // will result in a kernel code generation.
    if (std::all_of(
            group->exprs().begin(),
            group->exprs().end(),
            isLowerableAsStandaloneHostOp)) {
      NVF_ERROR(
          group->exprs().size() == 1,
          "Expr executed as a standalone op cannot be fused");
      hic->pushBackTopLevelExprs(ir_cloner.clone(group->exprs().at(0)));
    } else {
      auto host_unit = IrBuilder::create<hir::HostUnit>(
          staged_fusion->makeFusion(group).second);
      auto post_on_stream = IrBuilder::create<hir::PostOnStream>(
          host_unit, clone(group->inputs()), clone(group->outputs()));
      hic->pushBackTopLevelExprs(post_on_stream);
    }
  }
  for (auto input : staged_fusion->inputs()) {
    hic->addInput(ir_cloner.clone(input));
  }
  for (auto output : staged_fusion->outputs()) {
    hic->addOutput(ir_cloner.clone(output));
  }

  for (auto tv : hic->allTvs()) {
    // set all host tensors to global memory type. This must be the case by
    // definition of a host tensor, and setting the memory type to global is
    // also required to avoid Allocate HIR nodes to throw
    tv->setMemoryType(MemoryType::Global);
  }

  hir_pass::StreamParallelType().runPass(hic.get());

  hir_pass::ConvertOpToCommunication(params_).runPass(hic.get());

  return hic;
}

} // namespace nvfuser
Error Handling

The convertSingleOpToCommunication function does not handle errors in the same way as the original HostIrLower::lower method. Ensure that all error conditions are properly checked and handled to maintain consistency.

Code Duplication

The InsertDeallocations class now includes a check for hir::Deallocate that was previously in the insertDeallocations function. This check should be reviewed to ensure it is not redundant and that the function behaves as expected.

        !expr->isA<hir::Deallocate>(),
        "Expected hostir container to not have deallocate, but found one anyways");
  });
  std::unordered_map<TensorView*, int64_t> last_use;
  for (auto&& [i, expr] : enumerate(top_level_exprs)) {
    for (auto* val : expr->inputs()) {
      if (!val->isA<TensorView>()) {
        continue;
      }
      auto tv = val->as<TensorView>();
      last_use[tv] = i;
    }
  }

  std::vector<std::pair<int64_t, TensorView*>> last_use_by_index;
  last_use_by_index.reserve(last_use.size());
  for (auto&& [tv, i] : last_use) {
    last_use_by_index.emplace_back(i, tv);
  }
  std::sort(last_use_by_index.begin(), last_use_by_index.end());
  for (auto&& [i, tv] : last_use_by_index | std::views::reverse) {
    auto* deallocate = IrBuilder::create<hir::Deallocate>(tv);
    hic->insertExprAfter(i, deallocate);
  }
}

} // namespace nvfuser::hir_pass

@samnordmann
Copy link
Collaborator Author

!test

@samnordmann samnordmann force-pushed the new_hir_optim_pass_and_refactor branch from 7511c7c to fdd56a1 Compare May 9, 2025 13:17
@samnordmann
Copy link
Collaborator Author

!test

@samnordmann samnordmann requested a review from wujingyue May 9, 2025 13:24
const HostIrLowerParams& params = HostIrLowerParams())
: params_(params) {}

static std::vector<Expr*> ConvertSingleOpToCommunication(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the main stack, implemented in fusion_kernel_runtime.cpp, I consider this to be part of the fusion-IR-to-host-IR lowering (aka host IR lowering) instead of a host-IR-to-host-IR transformation pass (aka host IR pass).

Therefore, I prefer leaving this function in lower.h so it can be used in both host IR lowering and host IR passes.

Copy link
Collaborator Author

@samnordmann samnordmann May 12, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For the main stack, implemented in fusion_kernel_runtime.cpp, I consider this to be part of the fusion-IR-to-host-IR lowering (aka host IR lowering) instead of a host-IR-to-host-IR transformation pass (aka host IR pass)

Let me motivate the reason why I moved it:

  1. I personally I find it cleaner and more logical to keep this function separated, and close to the pass that uses it.
  2. The main stack should eventually use the pass and not the function. The function should eventually not even be exposed
  3. it is going to be a host-IR-to-host-IR pass after stream lowering will be integrated. To see it as part of lowering and not host IR pass is just because of the current implementation but is not necessary.

Therefore, I prefer leaving this function in lower.h so it can be used in both host IR lowering and host IR passes.

Would you be ok at least to put it a separate file host_ir/lower_to_communication.h|cpp? (as was named in the past)

Otherwise it would means that HostIrLower::lower from host_ir/lower.h calls the pass host_ir/pass/convert_op_to_communication.h which in turns calls HostIrLower:: ConvertSingleOpToCommunication from host_ir/lower.h, which is kind of cyclical and therefore misleading.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would you be ok at least to put it a separate file host_ir/lower_to_communication.h|cpp? (as was named in the past)

SGTM!

virtual ~OptimizationPass() = default;

protected:
virtual void passImplementation(Fusion* fusion) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This defeats the purpose of "recurring template pattern". I realized

virtual ~OptimizationPass() = default;
might have misled you and is probably (I'll give that a quick shot) unnecessary. See
DerivedClass::runPass(fusion);

for an example of how to avoid virtual methods.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This defeats the purpose of "recurring template pattern"

you're right, thanks. My problem is that I do not know how to deal with that ConvertOpToCommunication pass (and others in the future) has a non-trivial constructor. Maybe I am just missing something.

what about not using recurring template pattern but and keep the runPass in the base class and passImplementation as a private member of the derived class ?

Let me know what you suggest

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would the following work for non-trivial constructors?

template <typename Derived>
class Base {
public:
    void do_something() {
        static_cast<Derived*>(this)->implementation();
    }
};

class Derived : public Base<Derived> {
public:
    Derived(int x, std::string name) : x_(x), name_(std::move(name)) {}

    void implementation() {
        std::cout << "x: " << x_ << ", name: " << name_ << "\n";
    }

private:
    int x_;
    std::string name_;
};

(Again, it uses no virtual methods here)

Alternatively, if you don't plan to use CRTP for your host IR passes, please build hir::OptimizationPass as a non-template class and simplify. If you choose to do this, keep insertDeallocations as a helper function and call it from your hir::OptimizationPass. This way, I can wrap it differently if/when I need to.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! I was not so familiar with CRTP, now it is clearer. I have removed the virtual method. Let me know.

@samnordmann samnordmann requested a review from wujingyue May 12, 2025 15:12
Copy link
Collaborator

@wujingyue wujingyue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM otherwise!

@samnordmann
Copy link
Collaborator Author

!test

@samnordmann samnordmann requested a review from wujingyue May 13, 2025 17:33
@samnordmann
Copy link
Collaborator Author

!test

@samnordmann
Copy link
Collaborator Author

!test

@samnordmann samnordmann merged commit 70124e6 into main May 14, 2025
48 of 49 checks passed
@samnordmann samnordmann deleted the new_hir_optim_pass_and_refactor branch May 14, 2025 16:12
samnordmann added a commit that referenced this pull request May 28, 2025
Stacked on top of:
- #4401
- #4402 

Implements stream lowering to collective based pipelines.


# Test dumps:
generated with the command line:
```
mpirun -x NVFUSER_DUMP=host_ir -np 8 $BUILD_DIRECTORY/test_multidevice --gtest_filter=*MultiDeviceStreamParallelTypeTest.*
```

```
MultiDeviceStreamParallelTypeTest.Allgather

%HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i0 * i2 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx2{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx2{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T2_g_float[ideviceIdx.x4{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
    T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx2{i0}, index = StreamIdx )
    T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=i2, zero_init=false, resets_to_zero=false)
    Communication 39 (type=Allgather, team=(0 1 2 3 4 5 6 7), input=T2_g_float[ideviceIdx.x4{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
    Wait Communication 39
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer



MultiDeviceStreamParallelTypeTest.Allreduce

%HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i0 * i3 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx3{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx3{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T2_g_float[ideviceIdx.x6{i2}, iS7{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
    T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx3{i0}, index = StreamIdx )
    T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=i3, zero_init=false, resets_to_zero=false)
    Communication 39 (type=Allreduce, team=(0 1 2 3 4 5 6 7), input=T2_g_float[ideviceIdx.x6{i2}, iS7{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
    Wait Communication 39
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer



MultiDeviceStreamParallelTypeTest.ReduceScatter

%HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i0 * i3 ) * i4 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx4{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx4{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T2_g_float[ideviceIdx.x8{i2}, iS9{i3}, iS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
    T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx4{i0}, index = StreamIdx )
    T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i3 * i4 ), zero_init=false, resets_to_zero=false)
    Communication 48 (type=ReduceScatter, team=(0 1 2 3 4 5 6 7), input=T2_g_float[ideviceIdx.x8{i2}, iS9{i3}, iS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
    Wait Communication 48
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer



MultiDeviceStreamParallelTypeTest.AG_matmul

%HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[iS4{i5}, iS5{i6}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  T3_g_float[iStreamIdx11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iStreamIdx11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i0 * i2 ) * i3 ) * i4 ), zero_init=false, resets_to_zero=false)
  T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i0 * i2 ) * i3 ) * i6 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx11{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx11{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T4_g_float[ideviceIdx.x15{i2}, iS16{i3}, iS17{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
    T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T3_g_float[iStreamIdx11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx11{i0}, index = StreamIdx )
    T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i2 * i3 ) * i4 ), zero_init=false, resets_to_zero=false)
    Communication 59 (type=Allgather, team=(0 1 2 3 4 5 6 7), input=T4_g_float[ideviceIdx.x15{i2}, iS16{i3}, iS17{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
    Wait Communication 59
    T6_l_float[iS21{i2}, iS22{i3}, iS23{i6}, rS24{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx6{i0}, index = StreamIdx )
    T6_l_float[iS21{i2}, iS22{i3}, iS23{i6}, rS24{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = matmul(T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}),
                T1_g_float[iS4{i5}, iS5{i6}] (DeviceMesh{0 1 2 3 4 5 6 7}))
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer



MultiDeviceStreamParallelTypeTest.matmul_AR

%HostIrContainer { (T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[ideviceIdx.x4{i5}, iS5{i6}, iS6{i7}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  T2_g_float[ideviceIdx.x8{i2}, iStreamIdx7{i0}, iS9{i3}, iS10{i7}, rS11{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T2_g_float[ideviceIdx.x8{i2}, iStreamIdx7{i0}, iS9{i3}, iS10{i7}, rS11{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i2 * i0 ) * i3 ) * i7 ), zero_init=false, resets_to_zero=false)
  T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i0 * i3 ) * i7 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx7{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx7{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T4_l_float[ideviceIdx.x16{i2}, iS17{i3}, iS18{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
    T5_g_float[ideviceIdx.x19{i2}, iS20{i3}, iS21{i7}, rS22{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T2_g_float[ideviceIdx.x8{i2}, iStreamIdx7{i0}, iS9{i3}, iS10{i7}, rS11{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx7{i0}, index = StreamIdx )
    T5_g_float[ideviceIdx.x19{i2}, iS20{i3}, iS21{i7}, rS22{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = matmul(T4_l_float[ideviceIdx.x16{i2}, iS17{i3}, iS18{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}),
                T1_g_float[ideviceIdx.x4{i5}, iS5{i6}, iS6{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}))
    T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx12{i0}, index = StreamIdx )
    T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i3 * i7 ), zero_init=false, resets_to_zero=false)
    Communication 64 (type=Allreduce, team=(0 1 2 3 4 5 6 7), input=T5_g_float[ideviceIdx.x19{i2}, iS20{i3}, iS21{i7}, rS22{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
    Wait Communication 64
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer


MultiDeviceStreamParallelTypeTest.matmul_RS_through_bcast

%HostIrContainer { (T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}, iS4{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  PostOnStream (HostUnit0, Inputs:{T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), }, Outputs:{T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), })
  T3_g_float[ideviceIdx.x14{i2}, iStreamIdx13{i0}, iS15{i3}, iS16{i4}, iS17{i8}, rS18{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[ideviceIdx.x14{i2}, iStreamIdx13{i0}, iS15{i3}, iS16{i4}, iS17{i8}, rS18{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( ( i2 * i0 ) * i3 ) * i4 ) * i8 ), zero_init=false, resets_to_zero=false)
  T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i0 * i3 ) * i4 ) * i8 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx13{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx13{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T5_l_float[ideviceIdx.x24{i2}, iS25{i3}, iS26{i4}, iS27{i5}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}, iS4{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
    T6_l_float[ideviceIdx.x28{i6}, bS29{1}, iS30{i7}, iS31{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = bS8{1}, index = StreamIdx )
    T7_g_float[ideviceIdx.x32{i2}, iS33{i3}, iS34{i4}, iS35{i8}, rS36{i5}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T3_g_float[ideviceIdx.x14{i2}, iStreamIdx13{i0}, iS15{i3}, iS16{i4}, iS17{i8}, rS18{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx13{i0}, index = StreamIdx )
    T7_g_float[ideviceIdx.x32{i2}, iS33{i3}, iS34{i4}, iS35{i8}, rS36{i5}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = matmul(T5_l_float[ideviceIdx.x24{i2}, iS25{i3}, iS26{i4}, iS27{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}),
                T6_l_float[ideviceIdx.x28{i6}, bS29{1}, iS30{i7}, iS31{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}))
    T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx19{i0}, index = StreamIdx )
    T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i3 * i4 ) * i8 ), zero_init=false, resets_to_zero=false)
    Communication 80 (type=ReduceScatter, team=(0 1 2 3 4 5 6 7), input=T7_g_float[ideviceIdx.x32{i2}, iS33{i3}, iS34{i4}, iS35{i8}, rS36{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
    Wait Communication 80
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )

HostUnit0: Inputs={T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), } -> Outputs={T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), }Inputs:
  T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})
Outputs:
  T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})

%kernel {
T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})
   = broadcast( T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), flags = {true, false, true, false, false} )
} // %kernel
} // %HostIrContainer
```
nsarka pushed a commit to nsarka/Fuser that referenced this pull request Jul 28, 2025
Stacked on top of:
- NVIDIA#4401
- NVIDIA#4402 

Implements stream lowering to collective based pipelines.


# Test dumps:
generated with the command line:
```
mpirun -x NVFUSER_DUMP=host_ir -np 8 $BUILD_DIRECTORY/test_multidevice --gtest_filter=*MultiDeviceStreamParallelTypeTest.*
```

```
MultiDeviceStreamParallelTypeTest.Allgather

%HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i0 * i2 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx2{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx2{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T2_g_float[ideviceIdx.x4{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
    T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T1_g_float[iStreamIdx2{i0}, iS3{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx2{i0}, index = StreamIdx )
    T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=i2, zero_init=false, resets_to_zero=false)
    Communication 39 (type=Allgather, team=(0 1 2 3 4 5 6 7), input=T2_g_float[ideviceIdx.x4{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T3_g_float[iS5{i2}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
    Wait Communication 39
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer



MultiDeviceStreamParallelTypeTest.Allreduce

%HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i0 * i3 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx3{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx3{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T2_g_float[ideviceIdx.x6{i2}, iS7{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
    T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T1_g_float[iStreamIdx3{i0}, rS4{i2}, iS5{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx3{i0}, index = StreamIdx )
    T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=i3, zero_init=false, resets_to_zero=false)
    Communication 39 (type=Allreduce, team=(0 1 2 3 4 5 6 7), input=T2_g_float[ideviceIdx.x6{i2}, iS7{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T3_g_float[rS8{i2}, iS9{i3}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
    Wait Communication 39
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer



MultiDeviceStreamParallelTypeTest.ReduceScatter

%HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i0 * i3 ) * i4 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx4{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx4{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T2_g_float[ideviceIdx.x8{i2}, iS9{i3}, iS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
    T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T1_g_float[iStreamIdx4{i0}, rS5{i2}, ideviceIdx.x6{i3}, iS7{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx4{i0}, index = StreamIdx )
    T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i3 * i4 ), zero_init=false, resets_to_zero=false)
    Communication 48 (type=ReduceScatter, team=(0 1 2 3 4 5 6 7), input=T2_g_float[ideviceIdx.x8{i2}, iS9{i3}, iS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T3_g_float[rS11{i2}, ideviceIdx.x12{i3}, iS13{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
    Wait Communication 48
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer



MultiDeviceStreamParallelTypeTest.AG_matmul

%HostIrContainer { (T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[iS4{i5}, iS5{i6}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  T3_g_float[iStreamIdx11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iStreamIdx11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i0 * i2 ) * i3 ) * i4 ), zero_init=false, resets_to_zero=false)
  T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i0 * i2 ) * i3 ) * i6 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx11{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx11{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T4_g_float[ideviceIdx.x15{i2}, iS16{i3}, iS17{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T0_g_float[iS0{i0}, ideviceIdx.x1{i2}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
    T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T3_g_float[iStreamIdx11{i0}, iS12{i2}, iS13{i3}, iS14{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx11{i0}, index = StreamIdx )
    T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i2 * i3 ) * i4 ), zero_init=false, resets_to_zero=false)
    Communication 59 (type=Allgather, team=(0 1 2 3 4 5 6 7), input=T4_g_float[ideviceIdx.x15{i2}, iS16{i3}, iS17{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
    Wait Communication 59
    T6_l_float[iS21{i2}, iS22{i3}, iS23{i6}, rS24{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T2_g_float[iStreamIdx6{i0}, iS7{i2}, iS8{i3}, iS9{i6}, rS10{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx6{i0}, index = StreamIdx )
    T6_l_float[iS21{i2}, iS22{i3}, iS23{i6}, rS24{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = matmul(T5_g_float[iS18{i2}, iS19{i3}, iS20{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}),
                T1_g_float[iS4{i5}, iS5{i6}] (DeviceMesh{0 1 2 3 4 5 6 7}))
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer



MultiDeviceStreamParallelTypeTest.matmul_AR

%HostIrContainer { (T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[ideviceIdx.x4{i5}, iS5{i6}, iS6{i7}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  T2_g_float[ideviceIdx.x8{i2}, iStreamIdx7{i0}, iS9{i3}, iS10{i7}, rS11{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T2_g_float[ideviceIdx.x8{i2}, iStreamIdx7{i0}, iS9{i3}, iS10{i7}, rS11{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i2 * i0 ) * i3 ) * i7 ), zero_init=false, resets_to_zero=false)
  T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i0 * i3 ) * i7 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx7{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx7{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T4_l_float[ideviceIdx.x16{i2}, iS17{i3}, iS18{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
    T5_g_float[ideviceIdx.x19{i2}, iS20{i3}, iS21{i7}, rS22{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T2_g_float[ideviceIdx.x8{i2}, iStreamIdx7{i0}, iS9{i3}, iS10{i7}, rS11{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx7{i0}, index = StreamIdx )
    T5_g_float[ideviceIdx.x19{i2}, iS20{i3}, iS21{i7}, rS22{i4}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = matmul(T4_l_float[ideviceIdx.x16{i2}, iS17{i3}, iS18{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}),
                T1_g_float[ideviceIdx.x4{i5}, iS5{i6}, iS6{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}))
    T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T3_g_float[iStreamIdx12{i0}, rS13{i2}, iS14{i3}, iS15{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx12{i0}, index = StreamIdx )
    T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( i3 * i7 ), zero_init=false, resets_to_zero=false)
    Communication 64 (type=Allreduce, team=(0 1 2 3 4 5 6 7), input=T5_g_float[ideviceIdx.x19{i2}, iS20{i3}, iS21{i7}, rS22{i4}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T6_g_float[rS23{i2}, iS24{i3}, iS25{i7}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
    Wait Communication 64
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )
} // %HostIrContainer


MultiDeviceStreamParallelTypeTest.matmul_RS_through_bcast

%HostIrContainer { (T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}, iS4{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})) -> (T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})) :
  PostOnStream (HostUnit0, Inputs:{T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), }, Outputs:{T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), })
  T3_g_float[ideviceIdx.x14{i2}, iStreamIdx13{i0}, iS15{i3}, iS16{i4}, iS17{i8}, rS18{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T3_g_float[ideviceIdx.x14{i2}, iStreamIdx13{i0}, iS15{i3}, iS16{i4}, iS17{i8}, rS18{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( ( i2 * i0 ) * i3 ) * i4 ) * i8 ), zero_init=false, resets_to_zero=false)
  T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( ( i0 * i3 ) * i4 ) * i8 ), zero_init=false, resets_to_zero=false)
  GetCurrentStream into Stream 0
  FOR StreamIdx in iStreamIdx13{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    Synchronize Stream 0
  FOR StreamIdx in iStreamIdx13{i0}:
    SetCurrentStream to Stream ( StreamIdx % numberOfStreams )
    T5_l_float[ideviceIdx.x24{i2}, iS25{i3}, iS26{i4}, iS27{i5}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T0_g_float[ideviceIdx.x1{i2}, iS0{i0}, iS2{i3}, iS3{i4}, iS4{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iS0{i0}, index = StreamIdx )
    T6_l_float[ideviceIdx.x28{i6}, bS29{1}, iS30{i7}, iS31{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = bS8{1}, index = StreamIdx )
    T7_g_float[ideviceIdx.x32{i2}, iS33{i3}, iS34{i4}, iS35{i8}, rS36{i5}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T3_g_float[ideviceIdx.x14{i2}, iStreamIdx13{i0}, iS15{i3}, iS16{i4}, iS17{i8}, rS18{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx13{i0}, index = StreamIdx )
    T7_g_float[ideviceIdx.x32{i2}, iS33{i3}, iS34{i4}, iS35{i8}, rS36{i5}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = matmul(T5_l_float[ideviceIdx.x24{i2}, iS25{i3}, iS26{i4}, iS27{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}),
                T6_l_float[ideviceIdx.x28{i6}, bS29{1}, iS30{i7}, iS31{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}))
    T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})
       = HirAliasSelect( T4_g_float[iStreamIdx19{i0}, rS20{i2}, ideviceIdx.x21{i3}, iS22{i4}, iS23{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), axis = iStreamIdx19{i0}, index = StreamIdx )
    T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}) = ALLOCATE(buffer=T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), mem_type=global, size=( ( i3 * i4 ) * i8 ), zero_init=false, resets_to_zero=false)
    Communication 80 (type=ReduceScatter, team=(0 1 2 3 4 5 6 7), input=T7_g_float[ideviceIdx.x32{i2}, iS33{i3}, iS34{i4}, iS35{i8}, rS36{i5}] (DeviceMesh{0 1 2 3 4 5 6 7}), output=T8_g_float[rS37{i2}, ideviceIdx.x38{i3}, iS39{i4}, iS40{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), backend=NCCL)
    Wait Communication 80
    SetCurrentStream to Stream 0
    Synchronize Stream ( StreamIdx % numberOfStreams )

HostUnit0: Inputs={T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), } -> Outputs={T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), }Inputs:
  T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})
Outputs:
  T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})

%kernel {
T2_g_float[ideviceIdx.x9{i6}, bS8{1}, bS10{1}, iS11{i7}, iS12{i8}] (DeviceMesh{0 1 2 3 4 5 6 7})
   = broadcast( T1_g_float[ideviceIdx.x5{i6}, iS6{i7}, iS7{i8}] (DeviceMesh{0 1 2 3 4 5 6 7}), flags = {true, false, true, false, false} )
} // %kernel
} // %HostIrContainer
```
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants