
[Host Ir] stream lowering, first milestone: single device fusion #4148

Merged
samnordmann merged 49 commits into main from
host_irs/stream_lowering/single_device_fusions
Apr 28, 2025

Conversation

@samnordmann
Collaborator

@samnordmann samnordmann commented Mar 26, 2025

Porting PR #4147 here; this version is based on main.

This PR serves only as a reference, since it was broken down into the following PRs to ease review and merging:

  1. [Host irs] alias and preallocated output support #4144
  2. [Host Ir] refactor and cleanup lowering and segmentation #4145
  3. [Host ir] support for set reduce and binary op #4146
  4. [Host irs] Stream lowering of single device fusions #4147

@github-actions

github-actions bot commented Mar 26, 2025

Review updated until commit 35ff4da

Description

  • Added StreamParallelType pass for single device fusions in Host IR

  • Enhanced HostIrLower to handle stream parallelization

  • Introduced tests for stream parallel lowering

  • Updated parallel type string representation
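
The per-iteration stream bookkeeping this pass inserts (select a round-robin stream, sync with the original stream, compute, restore the original stream, sync the worker stream) can be sketched with a standalone toy model. The string-based "IR" and the name `lowerStreamLoop` are illustrative only, not nvfuser APIs:

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy model of the lowered loop body: iteration i runs on stream
// (i % numberOfStreams), synchronizing with the original stream before the
// compute and synchronizing the worker stream after restoring the original.
std::vector<std::string> lowerStreamLoop(int extent, int number_of_streams) {
  std::vector<std::string> program;
  for (int i = 0; i < extent; ++i) {
    const std::string s = "s" + std::to_string(i % number_of_streams);
    program.push_back("GetCurrentStream -> s_orig");
    program.push_back("SetCurrentStream " + s);
    program.push_back("Synchronize s_orig");
    program.push_back("compute slice " + std::to_string(i));
    program.push_back("SetCurrentStream s_orig");
    program.push_back("Synchronize " + s);
  }
  return program;
}
```

With extent 4 and 2 streams, iterations 0 and 2 land on s0 and iterations 1 and 3 on s1, mirroring the `mod(for_loop->index(), numberOfStreams)` indexing used in the pass.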


Changes walkthrough 📝

Relevant files:

Enhancement (10 files)
  • container.cpp: Conditional alias printing in HostIrContainer (+6/-4)
  • lower.cpp: Added StreamParallelType pass invocation and memory type setting (+10/-0)
  • stream_parallel_type.cpp: Implemented StreamParallelType pass logic (+440/-0)
  • type.cpp: Updated parallel type string for Stream (+1/-1)
  • fusion_definition.cpp: Temporarily disabled StreamParallelType pass in MultiDeviceExecutor (+5/-0)
  • test_multidevice_host_ir.cpp: Temporarily disabled StreamParallelType pass in tests (+8/-0)
  • executor.h: Added container access method in HostIrEvaluator (+4/-0)
  • stream_parallel_type.h: Added StreamParallelType pass declaration (+36/-0)
  • internal_nodes.h: Added iterDomain method to ForLoop (+4/-0)
  • executor.h: Added hostIrEvaluator access method in MultiDeviceExecutor (+4/-0)

Tests (1 file)
  • test_host_ir_stream_lowering.cpp: Added tests for StreamParallelType pass (+814/-0)

Cleanup (1 file)
  • optimization_pass.h: Removed unused FusionPass typedef (+0/-2)

Configuration changes (1 file)
  • CMakeLists.txt: Added StreamParallelType pass and test files (+2/-0)

PR Reviewer Guide 🔍

Here are some key observations to aid the review process:

🧪 PR contains tests
⚡ Recommended focus areas for review

Performance Concern

The implementation of the StreamParallelType pass should be evaluated for performance. Ensure that the pass does not introduce significant overhead and that the performance gains from stream parallelization are justified.

// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on

#include <host_ir/container.h>
#include <host_ir/lower.h>
#include <host_ir/pass/stream_parallel_type.h>
#include <id_model/id_model.h>
#include <ir/all_nodes.h>
#include <ir/builder.h>
#include <ir/internal_base_nodes.h>
#include <ir/utils.h>
#include <kernel_ir.h>
#include <ops/all_ops.h>
#include <ops/utils.h>

namespace nvfuser::hir {

namespace {

// Finds the stream axis in a tensor's domain. There should be at most one
// stream axis.
IterDomain* getStreamAxis(const std::vector<IterDomain*>& domain) {
  IterDomain* ret = nullptr;
  for (auto id : domain) {
    if (id->getParallelType() == ParallelType::Stream) {
      NVF_CHECK(
          ret == nullptr,
          "Expected at most one stream axis in the domain, but found ",
          id,
          " and ",
          ret);
      ret = id;
    }
  }
  return ret;
}

// Validates a tensor's stream axis: it must be an unmodified logical-domain
// axis of Iteration or Broadcast type.
void validateStreamAxis(IterDomain* stream_axis, const TensorView* tv) {
  // Find the stream axis in the logical domain
  auto it_logical_stream_axis = std::find(
      tv->getLogicalDomain().begin(),
      tv->getLogicalDomain().end(),
      stream_axis);

  // Verify stream axis is not split/merged
  NVF_ERROR(
      it_logical_stream_axis != tv->getLogicalDomain().end(),
      "Cannot stream parallelize on a split/merge axis ",
      stream_axis);

  // Verify stream axis is an iteration or broadcast axis
  NVF_CHECK(
      stream_axis->getIterType() == IterType::Iteration ||
          stream_axis->getIterType() == IterType::Broadcast,
      "Stream axis ",
      stream_axis,
      " should be an iteration or broadcast axis.");
}

// Checks if two iteration domains are mapped in the ID model's broadcast graph
bool areIdsMapped(const IdModel& id_model, IterDomain* id1, IterDomain* id2) {
  return id_model.idGraph(IdMappingMode::BROADCAST)
      .disjointValSets()
      .strictAreMapped(id1, id2);
}

// Determines if a stream-parallel for-loop can be merged with the previous one
bool canMergeWithPreviousForLoop(
    const std::vector<Expr*>& new_top_level_exprs,
    IterDomain* stream_axis,
    const IdModel& id_model) {
  return !new_top_level_exprs.empty() &&
      new_top_level_exprs.back()->isA<ForLoop>() &&
      areIdsMapped(
          id_model,
          stream_axis,
          new_top_level_exprs.back()->as<ForLoop>()->iterDomain());
}

// Finds the logical-domain index of the axis mapped to the stream axis, or
// returns -1 if the tensor has no such axis
int64_t findStreamAxisIndex(
    const TensorView* tv,
    IterDomain* stream_axis,
    const IdModel& id_model) {
  int64_t stream_id_logical_index = -1;
  for (auto id : tv->getLoopDomain()) {
    if (areIdsMapped(id_model, stream_axis, id)) {
      // Verify only one stream axis exists
      NVF_CHECK(
          stream_id_logical_index == -1,
          "Expected at most one axis mapping to the stream axis ",
          stream_axis,
          " in the tensor ",
          tv,
          " loop's domain ",
          tv->getLoopDomain());

      // Find stream axis in logical domain
      auto it_stream_id_logical = std::find(
          tv->getLogicalDomain().begin(), tv->getLogicalDomain().end(), id);
      NVF_CHECK(
          it_stream_id_logical != tv->getLogicalDomain().end(),
          "Expected to find ",
          id,
          " in ",
          tv,
          "'s logical domain ",
          tv->getLogicalDomain());
      stream_id_logical_index =
          std::distance(tv->getLogicalDomain().begin(), it_stream_id_logical);
    }
  }
  return stream_id_logical_index;
}

// Cache for tensor slicing operations in stream parallelization.
// This cache stores previously created sliced versions of tensors to avoid
// redundant slicing operations. A sliced tensor is created by removing a
// specific axis (stream axis) from the tensor's domain and creating a new
// tensor that represents a slice of the original tensor at a given index.
// The cache key is a tuple of (original tensor, axis index to remove, slice
// index).
struct TensorSlicingCache {
  // Type aliases
  using Key = std::tuple<TensorView*, int64_t, Val*>;

  // Custom hash function for the tuple used as cache key
  struct Hash {
    size_t operator()(const Key& t) const {
      auto [tv, idx, val] = t;
      return std::hash<TensorView*>{}(tv) ^ std::hash<int64_t>{}(idx) ^
          std::hash<Val*>{}(val);
    }
  };

  // Map type for storing cached sliced tensors
  using Map = std::unordered_map<Key, hir::HirAliasSelect*, Hash>;

  // Get the expr producing the indexed version of a tensor. If the expr already
  // exists in the cache, returns the cached version. Otherwise, creates a new
  // expr, producing a tensor "selected" on its dimension `stream_axis_index` at
  // index `index`. Returns a pair of (expr, is_new) where is_new indicates
  // whether the expr was newly created.
  std::pair<hir::HirAliasSelect*, bool> get(
      TensorView* tensor,
      int64_t stream_axis_index,
      Val* index) {
    auto key = std::make_tuple(tensor, stream_axis_index, index);
    auto it = cache_.find(key);
    if (it != cache_.end()) {
      return {it->second, false};
    }

    auto dom = tensor->getLogicalDomain();
    std::vector<IterDomain*> new_root;
    new_root.reserve(dom.size() - 1);

    for (auto i : arange((int64_t)dom.size())) {
      if (i != stream_axis_index) {
        new_root.emplace_back(dom[i]->cloneWithoutRFactor());
      }
    }

    auto td = IrBuilder::create<TensorDomain>(
        new_root, TensorDomain::getContiguityFilledWith(new_root, true));
    auto out = IrBuilder::create<TensorView>(td, *tensor->getDataType());
    auto result = IrBuilder::create<hir::HirAliasSelect>(
        tensor, out, stream_axis_index, index);

    cache_[key] = result;
    return {result, true};
  }

 private:
  Map cache_; // Storage for cached sliced tensors
};

// Step 1: Group expressions into stream-parallel regions
std::vector<Expr*> groupStreamParallelRegions(
    const std::vector<Expr*>& top_level_exprs,
    const IdModel& id_model) {
  std::vector<Expr*> new_top_level_exprs;

  for (auto* expr : top_level_exprs) {
    // Skip expressions with no outputs
    if (expr->outputs().size() == 0) {
      new_top_level_exprs.push_back(expr);
      continue;
    }

    // Each expression should have exactly one output
    NVF_CHECK(
        expr->outputs().size() == 1,
        "Each expr should have exactly one output.");

    // Get the output tensor and check for stream parallelization
    TensorView* output = expr->output(0)->as<TensorView>();
    IterDomain* stream_axis = getStreamAxis(output->getLoopDomain());

    // If no stream axis found, keep the expression as is
    if (stream_axis == nullptr) {
      new_top_level_exprs.push_back(expr);
      continue;
    }

    // Verify that the expression can be handled as a standalone host operation
    NVF_ERROR(
        HostIrLower::isLowerableAsStandaloneHostOp(expr),
        "Stream parallel type not supported for expr ",
        expr);

    // Validate stream axis
    validateStreamAxis(stream_axis, output);

    // Check if we can merge this expression with the previous for-loop
    if (canMergeWithPreviousForLoop(
            new_top_level_exprs, stream_axis, id_model)) {
      // Merge with existing for-loop by adding the expression to its body
      new_top_level_exprs.back()->as<ForLoop>()->body().push_back(expr);
    } else {
      // Create a new for-loop for stream parallelization
      auto* for_loop = IrBuilder::create<ForLoop>(
          stream_axis,
          /*index=*/NamedScalar::getParallelIndex(ParallelType::Stream),
          /*start=*/FusionGuard::getCurFusion()->zeroVal(),
          /*stop=*/stream_axis->extent(),
          /*step=*/FusionGuard::getCurFusion()->oneVal(),
          /*vectorize=*/false,
          /*vectorize_shift=*/nullptr,
          /*unroll_required=*/false,
          CircularBufferLoopStage::NotApplicable,
          /*circular_buffer_loop_stage_depth=*/0);
      // Add the expression to the new for-loop's body
      for_loop->body().push_back(expr);
      new_top_level_exprs.push_back(for_loop);
    }
  }

  return new_top_level_exprs;
}

// Step 2: Add allocations for tensors produced in stream-parallel for-loops
std::vector<Expr*> addTensorAllocations(
    std::vector<Expr*> top_level_exprs,
    const IdModel& id_model) {
  std::vector<Expr*> new_top_level_exprs;

  for (auto* expr : top_level_exprs) {
    if (expr->isA<ForLoop>()) {
      // Add allocations for tensors produced in the loop that have a stream
      // axis
      auto* for_loop = expr->as<ForLoop>();
      for (auto* body_expr : for_loop->body().exprs()) {
        for (auto* output :
             ir_utils::filterByType<TensorView>(body_expr->outputs())) {
          if (findStreamAxisIndex(output, for_loop->iterDomain(), id_model) !=
              -1) {
            new_top_level_exprs.push_back(
                IrBuilder::create<kir::Allocate>(output, MemoryType::Global));
          }
        }
      }
    }
    new_top_level_exprs.push_back(expr);
  }

  return new_top_level_exprs;
}

// Step 3: Process for-loop bodies by slicing tensors
std::vector<Expr*> processForLoopBodies(
    std::vector<Expr*> top_level_exprs,
    const IdModel& id_model) {
  TensorSlicingCache tensor_slicing_cache;

  for (auto* expr : top_level_exprs) {
    if (!expr->isA<ForLoop>()) {
      continue;
    }

    auto* for_loop = expr->as<ForLoop>();
    std::vector<Expr*> new_loop_body;

    // Lambda to process a tensor in a for-loop body
    auto processTensor = [&](Expr*& expr, TensorView* tensor) {
      if (auto stream_idx =
              findStreamAxisIndex(tensor, for_loop->iterDomain(), id_model);
          stream_idx != -1) {
        auto [slicing, is_new] =
            tensor_slicing_cache.get(tensor, stream_idx, for_loop->index());
        if (is_new) {
          new_loop_body.push_back(slicing);
        }
        expr = ir_utils::replaceValInExprInputs(expr, tensor, slicing->out());
        if (expr->outputs().size() > 0 && expr->outputs()[0] == tensor) {
          expr =
              ir_utils::transferDefinitionToNewOutputs(expr, {slicing->out()});
        }
      }
    };

    for (auto* body_expr : for_loop->body().exprs()) {
      for (auto* input :
           ir_utils::filterByType<TensorView>(body_expr->inputs())) {
        processTensor(body_expr, input);
      }
      for (auto* output :
           ir_utils::filterByType<TensorView>(body_expr->outputs())) {
        processTensor(body_expr, output);
      }
      new_loop_body.push_back(body_expr);
    }

    for_loop->body().clear();
    for (auto* expr : new_loop_body) {
      for_loop->body().push_back(expr);
    }
  }

  return top_level_exprs;
}

// Step 4: Add stream management and synchronization
std::vector<Expr*> addStreamManagement(std::vector<Expr*> top_level_exprs) {
  // Process each top-level expression
  for (auto* top_level_expr : top_level_exprs) {
    // Skip non-for-loop expressions
    if (!top_level_expr->isA<ForLoop>()) {
      continue;
    }

    auto* for_loop = top_level_expr->as<ForLoop>();
    std::vector<Expr*> new_loop_body;

    // Get the current stream before entering the loop
    auto* get_current_stream = IrBuilder::create<hir::GetCurrentStream>();
    hir::Stream* original_stream = get_current_stream->stream();
    new_loop_body.push_back(get_current_stream);

    // Set up a new stream for this iteration based on the loop index
    auto* number_of_streams =
        IrBuilder::create<NamedScalar>("numberOfStreams", DataType::Int);
    auto* stream_index = mod(for_loop->index(), number_of_streams);
    auto* stream = IrBuilder::create<hir::Stream>(stream_index);
    auto* set_stream = IrBuilder::create<hir::SetCurrentStream>(stream);
    new_loop_body.push_back(set_stream);

    // Synchronize with the original stream before starting computation
    auto* initial_sync_stream =
        IrBuilder::create<hir::Synchronize>(original_stream);
    new_loop_body.push_back(initial_sync_stream);

    // Add all the expressions to the loop body
    for (auto* expr : for_loop->body().exprs()) {
      new_loop_body.push_back(expr);
    }

    // Restore the original stream and synchronize with the iteration's stream
    auto* set_back_original_stream =
        IrBuilder::create<hir::SetCurrentStream>(original_stream);
    new_loop_body.push_back(set_back_original_stream);
    auto* sync_stream = IrBuilder::create<hir::Synchronize>(stream);
    new_loop_body.push_back(sync_stream);

    // Update the for-loop body with the new expressions
    for_loop->body().clear();
    for (auto* expr : new_loop_body) {
      for_loop->body().push_back(expr);
    }
  }

  return top_level_exprs;
}

} // anonymous namespace

// StreamParallelType pass implementation.
// This pass handles stream parallelization of operations in a fusion.
// It works by:
// 1. Identifying stream-parallelized axes in tensor operations
// 2. Grouping compatible operations into stream-parallel for-loops
// 3. Setting up proper stream synchronization and management
// 4. Adding allocations for tensors that need them
// The pass ensures that:
// - Input tensors don't have stream axes
// - Only one stream axis exists per tensor
// - Stream axes are properly synchronized
// - Operations are correctly grouped into stream-parallel regions
// - The resulting HostIrContainer's top level expression is valid for execution
// and does not contain any stream axes
//
// TODO: Here, we assume that the fusion input is a HostIrContainer and use the
// linear structure of HostIrContainer::topLevelExprs to greedily merge
// adjacent compatible stream for-loop bodies. Ideally we should look at the
// DAG and use the segmenter.
void StreamParallelType::runPass(Fusion* fusion) {
  // Verify that input tensors don't have stream axes
  NVF_CHECK(
      std::all_of(
          fusion->inputs().begin(),
          fusion->inputs().end(),
          [](Val* input) {
            auto input_tv = dynamic_cast<TensorView*>(input);
            return input_tv == nullptr ||
                getStreamAxis(input_tv->getLoopDomain()) == nullptr;
          }),
      "Expected no stream axis in the TensorView inputs.");

  // Set up the fusion environment and build the ID model
  FusionGuard fg(fusion);
  hir::HostIrContainer* hic = dynamic_cast<hir::HostIrContainer*>(fusion);
  NVF_CHECK(hic, "Expected HostIrContainer");

  IdModel id_model(fusion);
  id_model.buildBroadcastGraph();

  // Step 1: Group expressions into stream-parallel regions
  std::vector<Expr*> top_level_exprs =
      groupStreamParallelRegions(hic->topLevelExprs(), id_model);

  // Step 2: Add allocations for tensors that need them
  top_level_exprs = addTensorAllocations(std::move(top_level_exprs), id_model);

  // Step 3: Process for-loop bodies by slicing tensors
  top_level_exprs = processForLoopBodies(std::move(top_level_exprs), id_model);

  // Step 4: Add stream management and synchronization
  top_level_exprs = addStreamManagement(std::move(top_level_exprs));

  // Update the container's top-level expressions
  hic->resetTopLevelExprs(top_level_exprs);
}

} // namespace nvfuser::hir
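
To clarify Step 1's greedy grouping, here is a standalone sketch on a toy representation. The real pass compares axes through the ID model's broadcast graph rather than by integer equality, and the names `Group`/`groupRegions` are illustrative only:

```cpp
#include <cassert>
#include <vector>

// Toy model of groupStreamParallelRegions: each expr either has a stream axis
// id (>= 0) or none (-1). Adjacent exprs whose stream axes match are merged
// into one "for-loop" group; other exprs stay as their own top-level group.
struct Group {
  int stream_axis; // -1 means a plain top-level expr, not a for-loop
  std::vector<int> exprs;
};

std::vector<Group> groupRegions(const std::vector<int>& stream_axis_of_expr) {
  std::vector<Group> out;
  for (int e = 0; e < static_cast<int>(stream_axis_of_expr.size()); ++e) {
    const int axis = stream_axis_of_expr[e];
    if (axis != -1 && !out.empty() && out.back().stream_axis == axis) {
      // Merge with the previous compatible stream for-loop
      out.back().exprs.push_back(e);
    } else {
      // Start a new group (a new for-loop, or a plain top-level expr)
      out.push_back({axis, {e}});
    }
  }
  return out;
}
```

On input `{0, 0, -1, 0}` this yields three groups: one loop containing exprs 0 and 1, a plain expr 2, and a second loop containing expr 3. That matches the `ThreeSetOpsWithDisjointsForLoops` test below, where a non-parallelized op between two stream-parallelized ops forces two disjoint for-loops.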
Test Coverage

Ensure that the test cases cover a wide range of scenarios, including edge cases and potential failure modes. Verify that the tests are comprehensive and provide sufficient coverage for the StreamParallelType pass.

// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#include <gtest/gtest.h>

#include <fusion.h>
#include <host_ir/container.h>
#include <host_ir/executor.h>
#include <host_ir/lower.h>
#include <host_ir/pass/stream_parallel_type.h>
#include <ir/all_nodes.h>
#include <ir/builder.h>
#include <kernel_ir.h>
#include <multidevice/executor.h>
#include <ops/all_ops.h>
#include <tests/cpp/utils.h>

#include <algorithm>
#include <iostream>

namespace nvfuser {

namespace hir {

using HirLowerStreamTest = NVFuserTest;

TEST_F(HirLowerStreamTest, InputsAreNotStreamParallelized) {
  auto hic = std::make_unique<HostIrContainer>();
  FusionGuard fg(hic.get());
  TensorView* tv = makeContigTensor(2);
  hic->addInput(tv);
  tv->axis(0)->parallelize(ParallelType::Stream);

  EXPECT_ANY_THROW(
      preseg_passes::OptimizationPass<StreamParallelType>::runPass(hic.get()));
}

TEST_F(HirLowerStreamTest, Split) {
  auto hic = std::make_unique<HostIrContainer>();
  FusionGuard fg(hic.get());
  TensorView* tv0 = makeContigTensor(2);
  TensorView* tv1 = set(tv0);
  hic->addInput(tv0);
  hic->addOutput(tv1);
  hic->pushBackTopLevelExprs(tv1->definition());
  tv1->split(0, 2);
  tv1->axis(0)->parallelize(ParallelType::Stream);

  EXPECT_ANY_THROW(
      preseg_passes::OptimizationPass<StreamParallelType>::runPass(hic.get()));
}

TEST_F(HirLowerStreamTest, Merge) {
  auto hic = std::make_unique<HostIrContainer>();
  FusionGuard fg(hic.get());
  TensorView* tv0 = makeContigTensor(2);
  TensorView* tv1 = set(tv0);
  hic->addInput(tv0);
  hic->addOutput(tv1);
  hic->pushBackTopLevelExprs(tv1->definition());
  tv1->merge(0, 1);
  tv1->axis(0)->parallelize(ParallelType::Stream);

  EXPECT_ANY_THROW(
      preseg_passes::OptimizationPass<StreamParallelType>::runPass(hic.get()));
}

TEST_F(HirLowerStreamTest, SingleSetOp) {
  auto hic = std::make_unique<HostIrContainer>();
  FusionGuard fg(hic.get());
  TensorView* tv0 = makeContigTensor(2);
  TensorView* tv1 = set(tv0);
  hic->addInput(tv0);
  hic->addOutput(tv1);
  hic->pushBackTopLevelExprs(tv1->definition());
  tv0->setMemoryType(MemoryType::Global);
  tv1->setMemoryType(MemoryType::Global);
  tv1->axis(0)->parallelize(ParallelType::Stream);

  preseg_passes::OptimizationPass<StreamParallelType>::runPass(hic.get());

  EXPECT_EQ(hic->topLevelExprs().size(), 2);
  EXPECT_TRUE(hic->topLevelExprs().at(0)->isA<kir::Allocate>());
  EXPECT_TRUE(hic->topLevelExprs().at(1)->isA<ForLoop>());

  HostIrEvaluator hie(std::move(hic));

  auto options = at::TensorOptions().device(at::kCUDA, 0);
  at::Tensor input = at::rand({4, 8}, options);
  auto output = hie.runWithInput({{tv0, input}})[0].as<at::Tensor>();

  torch::cuda::synchronize();
  EXPECT_TRUE(output.equal(input))
      << "Output: " << output << " Expected: " << input;
}

TEST_F(HirLowerStreamTest, SingleSetOpNonOutermost) {
  auto hic = std::make_unique<HostIrContainer>();
  FusionGuard fg(hic.get());
  TensorView* tv0 = makeContigTensor(2);
  TensorView* tv1 = set(tv0);
  hic->addInput(tv0);
  hic->addOutput(tv1);
  hic->pushBackTopLevelExprs(tv1->definition());
  tv0->setMemoryType(MemoryType::Global);
  tv1->setMemoryType(MemoryType::Global);
  tv1->axis(1)->parallelize(ParallelType::Stream);

  preseg_passes::OptimizationPass<StreamParallelType>::runPass(hic.get());

  EXPECT_EQ(hic->topLevelExprs().size(), 2);
  EXPECT_TRUE(hic->topLevelExprs().at(0)->isA<kir::Allocate>());
  EXPECT_TRUE(hic->topLevelExprs().at(1)->isA<ForLoop>());

  HostIrEvaluator hie(std::move(hic));

  auto options = at::TensorOptions().device(at::kCUDA, 0);
  at::Tensor input = at::rand({4, 8}, options);
  auto output = hie.runWithInput({{tv0, input}})[0].as<at::Tensor>();

  torch::cuda::synchronize();
  EXPECT_TRUE(output.equal(input))
      << "Output: " << output << " Expected: " << input;
}

TEST_F(HirLowerStreamTest, SingleBinaryOp) {
  auto hic = std::make_unique<HostIrContainer>();
  FusionGuard fg(hic.get());
  TensorView* tv0 = makeContigTensor(2);
  TensorView* tv1 = makeContigTensor(2);
  TensorView* tv2 = add(tv0, tv1);
  hic->addInput(tv0);
  hic->addInput(tv1);
  hic->addOutput(tv2);
  hic->pushBackTopLevelExprs(tv2->definition());
  tv0->setMemoryType(MemoryType::Global);
  tv1->setMemoryType(MemoryType::Global);
  tv2->setMemoryType(MemoryType::Global);
  tv2->axis(0)->parallelize(ParallelType::Stream);

  preseg_passes::OptimizationPass<StreamParallelType>::runPass(hic.get());

  EXPECT_EQ(hic->topLevelExprs().size(), 2);
  EXPECT_TRUE(hic->topLevelExprs().at(0)->isA<kir::Allocate>());
  EXPECT_TRUE(hic->topLevelExprs().at(1)->isA<ForLoop>());

  HostIrEvaluator hie(std::move(hic));

  auto options = at::TensorOptions().device(at::kCUDA, 0);
  at::Tensor tv0_input = at::rand({4, 4}, options);
  at::Tensor tv1_input = at::rand({4, 4}, options);
  auto output = hie.runWithInput({{tv0, tv0_input}, {tv1, tv1_input}})[0]
                    .as<at::Tensor>();
  auto expected_output = tv0_input + tv1_input;
  EXPECT_TRUE(output.equal(expected_output))
      << "Output: " << output << "Expected: " << expected_output;
}

TEST_F(HirLowerStreamTest, TwoSetOps) {
  auto hic = std::make_unique<HostIrContainer>();
  FusionGuard fg(hic.get());
  TensorView* tv0 = makeContigTensor(2);
  TensorView* tv1 = set(tv0);
  TensorView* tv2 = set(tv1);
  hic->addInput(tv0);
  hic->addOutput(tv2);
  hic->pushBackTopLevelExprs(tv1->definition());
  hic->pushBackTopLevelExprs(tv2->definition());
  tv0->setMemoryType(MemoryType::Global);
  tv1->setMemoryType(MemoryType::Global);
  tv2->setMemoryType(MemoryType::Global);
  tv1->axis(0)->parallelize(ParallelType::Stream);
  tv2->axis(0)->parallelize(ParallelType::Stream);

  preseg_passes::OptimizationPass<StreamParallelType>::runPass(hic.get());

  EXPECT_EQ(hic->topLevelExprs().size(), 3);
  EXPECT_TRUE(hic->topLevelExprs().at(0)->isA<kir::Allocate>());
  EXPECT_TRUE(hic->topLevelExprs().at(1)->isA<kir::Allocate>());
  EXPECT_TRUE(hic->topLevelExprs().at(2)->isA<ForLoop>());

  HostIrEvaluator hie(std::move(hic));

  auto options = at::TensorOptions().device(at::kCUDA, 0);
  at::Tensor input = at::rand({4, 8}, options);
  auto output = hie.runWithInput({{tv0, input}})[0].as<at::Tensor>();

  torch::cuda::synchronize();
  EXPECT_TRUE(output.equal(input))
      << "Output: " << output << " Expected: " << input;
}

TEST_F(HirLowerStreamTest, ThreeSetOpsWithDisjointsForLoops) {
  auto hic = std::make_unique<HostIrContainer>();
  FusionGuard fg(hic.get());
  TensorView* tv0 = makeContigTensor(2);
  TensorView* tv1 = set(tv0);
  TensorView* tv2 = set(tv1);
  TensorView* tv3 = set(tv2);
  hic->addInput(tv0);
  hic->addOutput(tv3);
  hic->pushBackTopLevelExprs(tv1->definition());
  hic->pushBackTopLevelExprs(tv2->definition());
  hic->pushBackTopLevelExprs(tv3->definition());
  tv0->setMemoryType(MemoryType::Global);
  tv1->setMemoryType(MemoryType::Global);
  tv2->setMemoryType(MemoryType::Global);
  tv3->setMemoryType(MemoryType::Global);
  tv1->axis(0)->parallelize(ParallelType::Stream);
  tv3->axis(0)->parallelize(ParallelType::Stream);

  preseg_passes::OptimizationPass<StreamParallelType>::runPass(hic.get());

  EXPECT_EQ(hic->topLevelExprs().size(), 5);
  EXPECT_TRUE(hic->topLevelExprs().at(0)->isA<kir::Allocate>());
  EXPECT_TRUE(hic->topLevelExprs().at(1)->isA<ForLoop>());
  EXPECT_TRUE(hic->topLevelExprs().at(2)->isA<LoadStoreOp>());
  EXPECT_TRUE(hic->topLevelExprs().at(3)->isA<kir::Allocate>());
  EXPECT_TRUE(hic->topLevelExprs().at(4)->isA<ForLoop>());

  HostIrEvaluator hie(std::move(hic));

  auto options = at::TensorOptions().device(at::kCUDA, 0);
  at::Tensor input = at::rand({4, 8}, options);
  auto output = hie.runWithInput({{tv0, input}})[0].as<at::Tensor>();

  torch::cuda::synchronize();
  EXPECT_TRUE(output.equal(input))
      << "Output: " << output << " Expected: " << input;
}

TEST_F(HirLowerStreamTest, ReductionUnsupported) {
  auto hic = std::make_unique<HostIrContainer>();
  FusionGuard fg(hic.get());
  TensorView* tv0 = makeContigTensor(2);
  TensorView* tv1 = sum(tv0, {0});
  hic->addInput(tv0);
  hic->addOutput(tv1);
  hic->pushBackTopLevelExprs(tv1->definition());
  tv0->setMemoryType(MemoryType::Global);
  tv1->setMemoryType(MemoryType::Global);
  tv1->axis(0)->parallelize(ParallelType::Stream);

  EXPECT_ANY_THROW(
      preseg_passes::OptimizationPass<StreamParallelType>::runPass(hic.get()));
}

TEST_F(HirLowerStreamTest, Reduction) {
  auto hic = std::make_unique<HostIrContainer>();
  FusionGuard fg(hic.get());
  TensorView* tv0 = makeContigTensor(3);
  TensorView* tv1 = sum(tv0, {2});
  hic->addInput(tv0);
  hic->addOutput(tv1);
  hic->pushBackTopLevelExprs(tv1->definition());
  tv0->setMemoryType(MemoryType::Global);
  tv1->setMemoryType(MemoryType::Global);
  tv1->axis(0)->parallelize(ParallelType::Stream);

  preseg_passes::OptimizationPass<StreamParallelType>::runPass(hic.get());

  EXPECT_EQ(hic->topLevelExprs().size(), 2);
  EXPECT_TRUE(hic->topLevelExprs().at(0)->isA<kir::Allocate>());
  EXPECT_TRUE(hic->topLevelExprs().at(1)->isA<ForLoop>());

  HostIrEvaluator hie(std::move(hic));

  auto options = at::TensorOptions().device(at::kCUDA, 0);
  at::Tensor input = at::rand({4, 8, 2}, options);
  auto output = hie.runWithInput({{tv0, input}})[0].as<at::Tensor>();

  torch::cuda::synchronize();
  auto expected_output = input.sum(2);
  EXPECT_TRUE(output.equal(expected_output))
      << "Output: " << output << " Expected: " << expected_output;
}

TEST_F(HirLowerStreamTest, Matmul_M) {
  auto hic = std::make_unique<HostIrContainer>();
  FusionGuard fg(hic.get());
  TensorView* a = makeContigTensor(2);
  TensorView* b = makeContigTensor(2);
  TensorView* c = matmul(a, b);
  hic->addInput(a);
  hic->addInput(b);
  hic->addOutput(c);
  hic->pushBackTopLevelExprs(c->definition());
  a->setMemoryType(MemoryType::Global);
  b->setMemoryType(MemoryType::Global);
  c->setMemoryType(MemoryType::Global);
  c->axis(0)->parallelize(ParallelType::Stream);

  preseg_passes::OptimizationPass<StreamParallelType>::runPass(hic.get());

  EXPECT_EQ(hic->topLevelExprs().size(), 2);
  EXPECT_TRUE(hic->topLevelExprs().at(0)->isA<kir::Allocate>());
  EXPECT_TRUE(hic->topLevelExprs().at(1)->isA<ForLoop>());

  HostIrEvaluator hie(std::move(hic));

  constexpr int64_t M = 8, K = 4, N = 2;
  auto options = at::TensorOptions().device(at::kCUDA, 0);
  at::Tensor a_aten = at::rand({M, K}, options);
  at::Tensor b_aten = at::rand({K, N}, options);
  auto output =
      hie.runWithInput({{a, a_aten}, {b, b_aten}})[0].as<at::Tensor>();

  torch::cuda::synchronize();
  auto expected_output = at::matmul(a_aten, b_aten);
  EXPECT_TRUE(torch::allclose(output, expected_output, 1e-2, 1e-2))
      << "Output: " << output << " Expected: " << expected_output;
}

TEST_F(HirLowerStreamTest, BatchedMatmul) {
  auto hic = std::make_unique<HostIrContainer>();
  FusionGuard fg(hic.get());
  TensorView* a = makeContigTensor(3);
  TensorView* b = makeContigTensor(2);
  TensorView* c = matmul(a, b);
  hic->addInput(a);
  hic->addInput(b);
  hic->addOutput(c);
  hic->pushBackTopLevelExprs(c->definition());
  a->setMemoryType(MemoryType::Global);
  b->setMemoryType(MemoryType::Global);
  c->setMemoryType(MemoryType::Global);
  c->axis(0)->parallelize(ParallelType::Stream);

  preseg_passes::OptimizationPass<StreamParallelType>::runPass(hic.get());

  EXPECT_EQ(hic->topLevelExprs().size(), 2);
  EXPECT_TRUE(hic->topLevelExprs().at(0)->isA<kir::Allocate>());
  EXPECT_TRUE(hic->topLevelExprs().at(1)->isA<ForLoop>());

  HostIrEvaluator hie(std::move(hic));

  constexpr int64_t B = 16, M = 8, K = 4, N = 2;
  auto options = at::TensorOptions().device(at::kCUDA, 0);
  at::Tensor a_aten = at::rand({B, M, K}, options);
  at::Tensor b_aten = at::rand({K, N}, options);
  auto output =
      hie.runWithInput({{a, a_aten}, {b, b_aten}})[0].as<at::Tensor>();

  torch::cuda::synchronize();
  auto expected_output = at::matmul(a_aten, b_aten);
  EXPECT_TRUE(torch::allclose(output, expected_output, 1e-2, 1e-2))
      << "Output: " << output << " Expected: " << expected_output;
}

TEST_F(HirLowerStreamTest, Matmul_N) {
  auto hic = std::make_unique<HostIrContainer>();
  FusionGuard fg(hic.get());
  TensorView* a = makeContigTensor(2);
  TensorView* b = makeContigTensor(2);
  TensorView* c = matmul(a, b);
  hic->addInput(a);
  hic->addInput(b);
  hic->addOutput(c);
  hic->pushBackTopLevelExprs(c->definition());
  a->setMemoryType(MemoryType::Global);
  b->setMemoryType(MemoryType::Global);
  c->setMemoryType(MemoryType::Global);
  c->axis(1)->parallelize(ParallelType::Stream);

  preseg_passes::OptimizationPass<StreamParallelType>::runPass(hic.get());

  EXPECT_EQ(hic->topLevelExprs().size(), 2);
  EXPECT_TRUE(hic->topLevelExprs().at(0)->isA<kir::Allocate>());
  EXPECT_TRUE(hic->topLevelExprs().at(1)->isA<ForLoop>());

  HostIrEvaluator hie(std::move(hic));

  constexpr int64_t M = 8, K = 4, N = 2;
  auto options = at::TensorOptions().device(at::kCUDA, 0);
  at::Tensor a_aten = at::rand({M, K}, options);
  at::Tensor b_aten = at::rand({K, N}, options);
  auto output =
      hie.runWithInput({{a, a_aten}, {b, b_aten}})[0].as<at::Tensor>();

  torch::cuda::synchronize();
  auto expected_output = at::matmul(a_aten, b_aten);
  EXPECT_TRUE(torch::allclose(output, expected_output, 1e-2, 1e-2))
      << "Output: " << output << " Expected: " << expected_output;
}

TEST_F(HirLowerStreamTest, Matmul_K) {
  auto hic = std::make_unique<HostIrContainer>();
  FusionGuard fg(hic.get());
  TensorView* a = makeContigTensor(2);
  TensorView* b = makeContigTensor(2);
  TensorView* c = matmul(a, b);
  hic->addInput(a);
  hic->addInput(b);
  hic->addOutput(c);
  hic->pushBackTopLevelExprs(c->definition());
  a->setMemoryType(MemoryType::Global);
  b->setMemoryType(MemoryType::Global);
  c->setMemoryType(MemoryType::Global);
  c->axis(-1)->parallelize(ParallelType::Stream);

  EXPECT_ANY_THROW(
      preseg_passes::OptimizationPass<StreamParallelType>::runPass(hic.get()));
}

// We don't support PostOnStream because it does not handle pre-allocated
// outputs well. There is no strong motivation to support PostOnStream.
TEST_F(HirLowerStreamTest, DoNotSupportPostOnStream) {
  const std::vector<int64_t> input_sizes = {4, 8, 32};
  const std::vector<int64_t> output_sizes = {
      input_sizes.at(1), input_sizes.at(2)};

  auto get_fusion = [input_sizes]() -> std::unique_ptr<Fusion> {
    auto fusion = std::make_unique<Fusion>();
    FusionGuard fg(fusion.get());

    auto tv0 = makeConcreteTensor(input_sizes);
    auto tv1 = add(tv0, tv0);
    auto tv2 = sum(tv1, {0});
    fusion->addInput(tv0);
    fusion->addOutput(tv2);
    return fusion;
  };

  auto hic = std::make_unique<HostIrContainer>();
  FusionGuard fg(hic.get());

  auto host_unit = IrBuilder::create<HostUnit>(get_fusion());

  IrCloner ir_cloner(hic.get());
  TensorView* input =
      ir_cloner.clone(host_unit->fusion_to_execute()->inputs().at(0))
          ->as<TensorView>();
  TensorView* output =
      ir_cloner.clone(host_unit->fusion_to_execute()->outputs().at(0))
          ->as<TensorView>();

  std::vector<Val*> inputs = {input};
  std::vector<Val*> outputs = {output};
  auto post_on_stream =
      IrBuilder::create<PostOnStream>(host_unit, inputs, outputs);

  hic->pushBackTopLevelExprs(post_on_stream);

  hic->addInput(input);
  hic->addOutput(output);

  output->axis(-1)->parallelize(ParallelType::Stream);

  EXPECT_ANY_THROW(
      preseg_passes::OptimizationPass<StreamParallelType>::runPass(hic.get()));
}

} // namespace hir

using MultiDeviceExecutorLowerStreamTest = NVFuserTest;

TEST_F(MultiDeviceExecutorLowerStreamTest, InputsAreNotStreamParallelized) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());
  TensorView* tv = makeContigTensor(2);
  fusion->addInput(tv);
  tv->axis(0)->parallelize(ParallelType::Stream);

  EXPECT_ANY_THROW(
      MultiDeviceExecutor(std::move(fusion), Communicator::getInstance()));
}

TEST_F(MultiDeviceExecutorLowerStreamTest, Split) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());
  TensorView* tv0 = makeContigTensor(2);
  TensorView* tv1 = set(tv0);
  fusion->addInput(tv0);
  fusion->addOutput(tv1);
  tv1->split(0, 2);
  tv1->axis(0)->parallelize(ParallelType::Stream);

  EXPECT_ANY_THROW(
      MultiDeviceExecutor(std::move(fusion), Communicator::getInstance()));
}

TEST_F(MultiDeviceExecutorLowerStreamTest, Merge) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());
  TensorView* tv0 = makeContigTensor(2);
  TensorView* tv1 = set(tv0);
  fusion->addInput(tv0);
  fusion->addOutput(tv1);
  tv1->merge(0, 1);
  tv1->axis(0)->parallelize(ParallelType::Stream);

  EXPECT_ANY_THROW(
      MultiDeviceExecutor(std::move(fusion), Communicator::getInstance()));
}

TEST_F(MultiDeviceExecutorLowerStreamTest, SingleSetOp) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());
  TensorView* tv0 = makeContigTensor(2);
  TensorView* tv1 = set(tv0);
  fusion->addInput(tv0);
  fusion->addOutput(tv1);
  tv1->axis(0)->parallelize(ParallelType::Stream);

  MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance());

  hir::HostIrContainer* container = executor.hostIrEvaluator()->container();
  EXPECT_EQ(container->topLevelExprs().size(), 2);
  EXPECT_TRUE(container->topLevelExprs().at(0)->isA<kir::Allocate>());
  EXPECT_TRUE(container->topLevelExprs().at(1)->isA<ForLoop>());

  auto options = at::TensorOptions().device(at::kCUDA, 0);
  at::Tensor input = at::rand({4, 8}, options);
  auto output =
      executor.runWithInput(KernelArgumentHolder({input}))[0].as<at::Tensor>();

  torch::cuda::synchronize();
  EXPECT_TRUE(output.equal(input))
      << "Output: " << output << " Expected: " << input;
}

TEST_F(MultiDeviceExecutorLowerStreamTest, SingleSetOpNonOutermost) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());
  TensorView* tv0 = makeContigTensor(2);
  TensorView* tv1 = set(tv0);
  fusion->addInput(tv0);
  fusion->addOutput(tv1);
  tv1->axis(1)->parallelize(ParallelType::Stream);

  MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance());

  hir::HostIrContainer* container = executor.hostIrEvaluator()->container();
  EXPECT_EQ(container->topLevelExprs().size(), 2);
  EXPECT_TRUE(container->topLevelExprs().at(0)->isA<kir::Allocate>());
  EXPECT_TRUE(container->topLevelExprs().at(1)->isA<ForLoop>());

  auto options = at::TensorOptions().device(at::kCUDA, 0);
  at::Tensor input = at::rand({4, 8}, options);
  auto output =
      executor.runWithInput(KernelArgumentHolder({input}))[0].as<at::Tensor>();

  torch::cuda::synchronize();
  EXPECT_TRUE(output.equal(input))
      << "Output: " << output << " Expected: " << input;
}

TEST_F(MultiDeviceExecutorLowerStreamTest, SingleBinaryOp) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());
  TensorView* tv0 = makeContigTensor(2);
  TensorView* tv1 = makeContigTensor(2);
  TensorView* tv2 = add(tv0, tv1);
  fusion->addInput(tv0);
  fusion->addInput(tv1);
  fusion->addOutput(tv2);
  tv2->axis(0)->parallelize(ParallelType::Stream);

  MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance());

  hir::HostIrContainer* container = executor.hostIrEvaluator()->container();
  EXPECT_EQ(container->topLevelExprs().size(), 2);
  EXPECT_TRUE(container->topLevelExprs().at(0)->isA<kir::Allocate>());
  EXPECT_TRUE(container->topLevelExprs().at(1)->isA<ForLoop>());

  auto options = at::TensorOptions().device(at::kCUDA, 0);

  at::Tensor tv0_input = at::rand({4, 4}, options);
  at::Tensor tv1_input = at::rand({4, 4}, options);
  auto output =
      executor.runWithInput(KernelArgumentHolder({tv0_input, tv1_input}))[0]
          .as<at::Tensor>();
  auto expected_output = tv0_input + tv1_input;
  EXPECT_TRUE(output.equal(expected_output))
      << "Output: " << output << "Expected: " << expected_output;
}

TEST_F(MultiDeviceExecutorLowerStreamTest, TwoSetOps) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());
  TensorView* tv0 = makeContigTensor(2);
  TensorView* tv1 = set(tv0);
  TensorView* tv2 = set(tv1);
  fusion->addInput(tv0);
  fusion->addOutput(tv2);
  tv1->axis(0)->parallelize(ParallelType::Stream);
  tv2->axis(0)->parallelize(ParallelType::Stream);

  MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance());

  hir::HostIrContainer* container = executor.hostIrEvaluator()->container();
  EXPECT_EQ(container->topLevelExprs().size(), 3);
  EXPECT_TRUE(container->topLevelExprs().at(0)->isA<kir::Allocate>());
  EXPECT_TRUE(container->topLevelExprs().at(1)->isA<kir::Allocate>());
  EXPECT_TRUE(container->topLevelExprs().at(2)->isA<ForLoop>());

  auto options = at::TensorOptions().device(at::kCUDA, 0);
  at::Tensor input = at::rand({4, 8}, options);
  auto output =
      executor.runWithInput(KernelArgumentHolder({input}))[0].as<at::Tensor>();

  torch::cuda::synchronize();
  EXPECT_TRUE(output.equal(input))
      << "Output: " << output << " Expected: " << input;
}

TEST_F(MultiDeviceExecutorLowerStreamTest, ThreeSetOpsWithDisjointsForLoops) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());
  TensorView* tv0 = makeContigTensor(2);
  TensorView* tv1 = set(tv0);
  TensorView* tv2 = set(tv1);
  TensorView* tv3 = set(tv2);
  fusion->addInput(tv0);
  fusion->addOutput(tv3);
  tv1->axis(0)->parallelize(ParallelType::Stream);
  tv3->axis(0)->parallelize(ParallelType::Stream);

  MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance());

  hir::HostIrContainer* container = executor.hostIrEvaluator()->container();
  EXPECT_EQ(container->topLevelExprs().size(), 5);
  EXPECT_TRUE(container->topLevelExprs().at(0)->isA<kir::Allocate>());
  EXPECT_TRUE(container->topLevelExprs().at(1)->isA<ForLoop>());
  EXPECT_TRUE(container->topLevelExprs().at(2)->isA<LoadStoreOp>());
  EXPECT_TRUE(container->topLevelExprs().at(3)->isA<kir::Allocate>());
  EXPECT_TRUE(container->topLevelExprs().at(4)->isA<ForLoop>());

  auto options = at::TensorOptions().device(at::kCUDA, 0);
  at::Tensor input = at::rand({4, 8}, options);
  auto output =
      executor.runWithInput(KernelArgumentHolder({input}))[0].as<at::Tensor>();

  torch::cuda::synchronize();
  EXPECT_TRUE(output.equal(input))
      << "Output: " << output << " Expected: " << input;
}

TEST_F(MultiDeviceExecutorLowerStreamTest, ReductionUnsupported) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());
  TensorView* tv0 = makeContigTensor(2);
  TensorView* tv1 = sum(tv0, {0});
  fusion->addInput(tv0);
  fusion->addOutput(tv1);
  tv1->axis(0)->parallelize(ParallelType::Stream);

  EXPECT_ANY_THROW(
      MultiDeviceExecutor(std::move(fusion), Communicator::getInstance()));
}

TEST_F(MultiDeviceExecutorLowerStreamTest, Reduction) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());
  TensorView* tv0 = makeContigTensor(3);
  TensorView* tv1 = sum(tv0, {2});
  fusion->addInput(tv0);
  fusion->addOutput(tv1);
  tv1->axis(0)->parallelize(ParallelType::Stream);

  MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance());

  hir::HostIrContainer* container = executor.hostIrEvaluator()->container();
  EXPECT_EQ(container->topLevelExprs().size(), 2);
  EXPECT_TRUE(container->topLevelExprs().at(0)->isA<kir::Allocate>());
  EXPECT_TRUE(container->topLevelExprs().at(1)->isA<ForLoop>());

  auto options = at::TensorOptions().device(at::kCUDA, 0);
  at::Tensor input = at::rand({4, 8, 2}, options);
  auto output =
      executor.runWithInput(KernelArgumentHolder({input}))[0].as<at::Tensor>();

  torch::cuda::synchronize();
  auto expected_output = input.sum(2);
  EXPECT_TRUE(output.equal(expected_output))
      << "Output: " << output << " Expected: " << expected_output;
}

TEST_F(MultiDeviceExecutorLowerStreamTest, Matmul_M) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());
  TensorView* a = makeContigTensor(2);
  TensorView* b = makeContigTensor(2);
  TensorView* c = matmul(a, b);
  fusion->addInput(a);
  fusion->addInput(b);
  fusion->addOutput(c);
  c->axis(0)->parallelize(ParallelType::Stream);

  MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance());

  hir::HostIrContainer* container = executor.hostIrEvaluator()->container();
  EXPECT_EQ(container->topLevelExprs().size(), 2);
  EXPECT_TRUE(container->topLevelExprs().at(0)->isA<kir::Allocate>());
  EXPECT_TRUE(container->topLevelExprs().at(1)->isA<ForLoop>());

  constexpr int64_t M = 8, K = 4, N = 2;
  auto options = at::TensorOptions().device(at::kCUDA, 0);
  at::Tensor a_aten = at::rand({M, K}, options);
  at::Tensor b_aten = at::rand({K, N}, options);
  auto output = executor.runWithInput(KernelArgumentHolder({a_aten, b_aten}))[0]
                    .as<at::Tensor>();

  torch::cuda::synchronize();
  auto expected_output = at::matmul(a_aten, b_aten);
  EXPECT_TRUE(torch::allclose(output, expected_output, 1e-2, 1e-2))
      << "Output: " << output << " Expected: " << expected_output;
}

TEST_F(MultiDeviceExecutorLowerStreamTest, BatchedMatmul) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());
  TensorView* a = makeContigTensor(3);
  TensorView* b = makeContigTensor(2);
  TensorView* c = matmul(a, b);
  fusion->addInput(a);
  fusion->addInput(b);
  fusion->addOutput(c);
  c->axis(0)->parallelize(ParallelType::Stream);

  MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance());

  hir::HostIrContainer* container = executor.hostIrEvaluator()->container();
  EXPECT_EQ(container->topLevelExprs().size(), 2);
  EXPECT_TRUE(container->topLevelExprs().at(0)->isA<kir::Allocate>());
  EXPECT_TRUE(container->topLevelExprs().at(1)->isA<ForLoop>());

  constexpr int64_t B = 16, M = 8, K = 4, N = 2;
  auto options = at::TensorOptions().device(at::kCUDA, 0);
  at::Tensor a_aten = at::rand({B, M, K}, options);
  at::Tensor b_aten = at::rand({K, N}, options);
  auto output = executor.runWithInput(KernelArgumentHolder({a_aten, b_aten}))[0]
                    .as<at::Tensor>();

  torch::cuda::synchronize();
  auto expected_output = at::matmul(a_aten, b_aten);
  EXPECT_TRUE(torch::allclose(output, expected_output, 1e-2, 1e-2))
      << "Output: " << output << " Expected: " << expected_output;
}

TEST_F(MultiDeviceExecutorLowerStreamTest, Matmul_N) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());
  TensorView* a = makeContigTensor(2);
  TensorView* b = makeContigTensor(2);
  TensorView* c = matmul(a, b);
  fusion->addInput(a);
  fusion->addInput(b);
  fusion->addOutput(c);
  c->axis(1)->parallelize(ParallelType::Stream);

  MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance());

  hir::HostIrContainer* container = executor.hostIrEvaluator()->container();
  EXPECT_EQ(container->topLevelExprs().size(), 2);
  EXPECT_TRUE(container->topLevelExprs().at(0)->isA<kir::Allocate>());
  EXPECT_TRUE(container->topLevelExprs().at(1)->isA<ForLoop>());

  constexpr int64_t M = 8, K = 4, N = 2;
  auto options = at::TensorOptions().device(at::kCUDA, 0);
  at::Tensor a_aten = at::rand({M, K}, options);
  at::Tensor b_aten = at::rand({K, N}, options);
  auto output = executor.runWithInput(KernelArgumentHolder({a_aten, b_aten}))[0]
                    .as<at::Tensor>();

  torch::cuda::synchronize();
  auto expected_output = at::matmul(a_aten, b_aten);
  EXPECT_TRUE(torch::allclose(output, expected_output, 1e-2, 1e-2))
      << "Output: " << output << " Expected: " << expected_output;
}

TEST_F(MultiDeviceExecutorLowerStreamTest, Matmul_K) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());
  TensorView* a = makeContigTensor(2);
  TensorView* b = makeContigTensor(2);
  TensorView* c = matmul(a, b);
  fusion->addInput(a);
  fusion->addInput(b);
  fusion->addOutput(c);
  c->axis(-1)->parallelize(ParallelType::Stream);

  EXPECT_ANY_THROW(
      MultiDeviceExecutor(std::move(fusion), Communicator::getInstance()));
}

// We only support the Stream parallel type on ops that support pre-allocated
// outputs, which means they need special handling in HostIrEvaluator and they
// need to be lowered as a Host IR op in the top-level expressions, not as a
// PostOnStream(HostUnit(.)). See HostIrLower::isLoweredAsStandaloneHostOp and
// the test HirLowerStreamTest.DoNotSupportPostOnStream.
TEST_F(MultiDeviceExecutorLowerStreamTest, DoNotSupportPostOnStream) {
  auto fusion = std::make_unique<Fusion>();
  FusionGuard fg(fusion.get());
  TensorView* tv0 = makeContigTensor(2);
  TensorView* tv1 =
      abs(tv0); // arbitrary example of an unsupported op. There is no deep
                // reason why we don't support it -- if needed we could widen
                // the support. But we want to make sure that an unsupported
                // op does not silently fail
  fusion->addInput(tv0);
  fusion->addOutput(tv1);
  tv1->axis(0)->parallelize(ParallelType::Stream);

  EXPECT_ANY_THROW(
      MultiDeviceExecutor(std::move(fusion), Communicator::getInstance()));
}

} // namespace nvfuser
Code Clarity

The addition of the StreamParallelType pass in the HostIrLower::lower function should be clearly documented. Ensure that the code is easy to understand and that the purpose of each step is well explained.

for (auto input : staged_fusion->inputs()) {
  hic->addInput(ir_cloner.clone(input));
}
for (auto output : staged_fusion->outputs()) {
  hic->addOutput(ir_cloner.clone(output));
}

for (auto tv : hic->allTvs()) {
  // Set all host tensors to global memory type. This must be the case by
  // definition of a host tensor, and setting the memory type to global is
  // also required to prevent Allocate HIR nodes from throwing.
  tv->setMemoryType(MemoryType::Global);
}

std::vector<Expr*> new_top_level_exprs;

@samnordmann samnordmann changed the title [Host Ir] stream lowering, first milestone [Host Ir] stream lowering, first milestone: single device fusion Apr 15, 2025
@samnordmann samnordmann force-pushed the host_irs/stream_lowering/single_device_fusions branch from 02d494e to b6c54f2 Compare April 16, 2025 10:50
@samnordmann samnordmann force-pushed the host_irs/stream_lowering/single_device_fusions branch from b8fa021 to 165bd1b Compare April 16, 2025 15:58
@samnordmann samnordmann force-pushed the host_irs/stream_lowering/single_device_fusions branch from ec1c9f5 to a50b53c Compare April 23, 2025 22:20
@samnordmann samnordmann marked this pull request as ready for review April 27, 2025 14:29
@samnordmann samnordmann requested a review from wujingyue April 27, 2025 14:31
@samnordmann
Collaborator Author

@wujingyue can you approve this PR? It is the same branch as in #4147 that you already approved, but based on main and not another branch of mine (somehow, automatically changing the base of the PR errored)

@samnordmann
Collaborator Author

!test

@samnordmann
Collaborator Author

!test

@samnordmann
Collaborator Author

!test

@samnordmann samnordmann merged commit 287ede2 into main Apr 28, 2025
60 checks passed
@samnordmann samnordmann deleted the host_irs/stream_lowering/single_device_fusions branch April 28, 2025 09:02