diff --git a/CMakeLists.txt b/CMakeLists.txt index 0f14bc07f07..b09c6adb22f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -213,6 +213,7 @@ list(APPEND NVFUSER_SRCS ${NVFUSER_SRCS_DIR}/preseg_passes/remove_empty.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/reorder_sharded_axis.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/segment_inplace_update.cpp + ${NVFUSER_SRCS_DIR}/host_ir/pass/stream_parallel_type.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/translate_no_reduction_matmul_to_mul_squeeze.cpp ${NVFUSER_SRCS_DIR}/preseg_passes/translate_repeat_to_expand.cpp ${NVFUSER_SRCS_DIR}/rng.cpp @@ -739,6 +740,7 @@ if(BUILD_TEST) list(APPEND HOSTIR_TEST_SRCS ${NVFUSER_ROOT}/tests/cpp/test_host_irs.cpp ${NVFUSER_ROOT}/tests/cpp/test_host_ir_integration.cpp + ${NVFUSER_ROOT}/tests/cpp/test_host_ir_stream_lowering.cpp ) add_test(test_host_ir "${HOSTIR_TEST_SRCS}" "") list(APPEND TEST_BINARIES test_host_ir) diff --git a/csrc/host_ir/container.cpp b/csrc/host_ir/container.cpp index 83e668770fc..9fdcfa376a6 100644 --- a/csrc/host_ir/container.cpp +++ b/csrc/host_ir/container.cpp @@ -35,11 +35,13 @@ Stream* HostIrContainer::getDefaultStream() { std::ostream& HostIrContainer::print(std::ostream& os) const { IrMathPrinter op_exprs(os); op_exprs.handle(this); - os << "Aliases:{"; - for (const auto& alias : alias_) { - os << "\n " << alias.first << " -> " << alias.second; + if (alias_.size() > 0) { + os << "Aliases:{"; + for (const auto& alias : alias_) { + os << "\n " << alias.first << " -> " << alias.second; + } + os << "\n}\n"; } - os << "\n}\n"; return os; } diff --git a/csrc/host_ir/executor.h b/csrc/host_ir/executor.h index 7894245de75..3f147b7801b 100644 --- a/csrc/host_ir/executor.h +++ b/csrc/host_ir/executor.h @@ -97,6 +97,10 @@ class HostIrEvaluator final : public OptOutDispatch { return container_->outputs(); } + auto* container() const { + return container_.get(); + } + std::ostream& print(std::ostream& os) const { return container_->print(os); }; diff --git a/csrc/host_ir/lower.cpp 
b/csrc/host_ir/lower.cpp index fd14096b190..ca9bb80ae4e 100644 --- a/csrc/host_ir/lower.cpp +++ b/csrc/host_ir/lower.cpp @@ -7,6 +7,7 @@ // clang-format on #include #include +#include #include #include #include @@ -735,6 +736,13 @@ std::unique_ptr HostIrLower::lower( hic->addOutput(ir_cloner.clone(output)); } + for (auto tv : hic->allTvs()) { + // set all host tensors to global memory type. This must be the case by + // definition of a host tensor, and setting the memory type to global is + // also required to avoid Allocate HIR nodes to throw + tv->setMemoryType(MemoryType::Global); + } + std::vector new_top_level_exprs; for (auto top_level_expr : hic->topLevelExprs()) { if (!isResharding(top_level_expr)) { @@ -761,6 +769,8 @@ std::unique_ptr HostIrLower::lower( } hic->resetTopLevelExprs(new_top_level_exprs); + preseg_passes::OptimizationPass::runPass(hic.get()); + return hic; } diff --git a/csrc/host_ir/pass/stream_parallel_type.cpp b/csrc/host_ir/pass/stream_parallel_type.cpp new file mode 100644 index 00000000000..d7bfa0f090a --- /dev/null +++ b/csrc/host_ir/pass/stream_parallel_type.cpp @@ -0,0 +1,440 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace nvfuser::hir { + +namespace { + +// Finds the stream axis in a tensor's domain. There should be at most one +// stream axis. 
+IterDomain* getStreamAxis(const std::vector& domain) { + IterDomain* ret = nullptr; + for (auto id : domain) { + if (id->getParallelType() == ParallelType::Stream) { + NVF_CHECK( + ret == nullptr, + "Expected at most one stream axis in the domain, but found ", + id, + " and ", + ret); + ret = id; + } + } + return ret; +} + +// Validates that a stream axis is valid in a tensor +void validateStreamAxis(IterDomain* stream_axis, const TensorView* tv) { + // Find the stream axis in the logical domain + auto it_logical_stream_axis = std::find( + tv->getLogicalDomain().begin(), + tv->getLogicalDomain().end(), + stream_axis); + + // Verify stream axis is not split/merged + NVF_ERROR( + it_logical_stream_axis != tv->getLogicalDomain().end(), + "Cannot stream parallelize on a split/merge axis ", + stream_axis); + + // Verify stream axis is an iteration or broadcast axis + NVF_CHECK( + stream_axis->getIterType() == IterType::Iteration || + stream_axis->getIterType() == IterType::Broadcast, + "Stream axis ", + stream_axis, + " should be an iteration or broadcast axis."); +} + +// Checks if two iteration domains are mapped in the ID model +bool areIdsMapped(const IdModel& id_model, IterDomain* id1, IterDomain* id2) { + return id_model.idGraph(IdMappingMode::BROADCAST) + .disjointValSets() + .strictAreMapped(id1, id2); +} + +// Determines if a stream-parallel for-loop can be merged with the previous one +bool canMergeWithPreviousForLoop( + const std::vector& new_top_level_exprs, + IterDomain* stream_axis, + const IdModel& id_model) { + return !new_top_level_exprs.empty() && + new_top_level_exprs.back()->isA() && + areIdsMapped( + id_model, + stream_axis, + new_top_level_exprs.back()->as()->iterDomain()); +} + +// Finds where a stream axis appears in a tensor's logical domain +int64_t findStreamAxisIndex( + const TensorView* tv, + IterDomain* stream_axis, + const IdModel& id_model) { + int64_t stream_id_logical_index = -1; + for (auto id : tv->getLoopDomain()) { + if 
(areIdsMapped(id_model, stream_axis, id)) { + // Verify only one stream axis exists + NVF_CHECK( + stream_id_logical_index == -1, + "Expected at most one axis mapping to the stream axis ", + stream_axis, + " in the tensor ", + tv, + " loop's domain ", + tv->getLoopDomain()); + + // Find stream axis in logical domain + auto it_stream_id_logical = std::find( + tv->getLogicalDomain().begin(), tv->getLogicalDomain().end(), id); + NVF_CHECK( + it_stream_id_logical != tv->getLogicalDomain().end(), + "Expected to find ", + id, + " in ", + tv, + "'s logical domain ", + tv->getLogicalDomain()); + stream_id_logical_index = + std::distance(tv->getLogicalDomain().begin(), it_stream_id_logical); + } + } + return stream_id_logical_index; +} + +// Cache for tensor slicing operations in stream parallelization. +// This cache stores previously created sliced versions of tensors to avoid +// redundant slicing operations. A sliced tensor is created by removing a +// specific axis (stream axis) from the tensor's domain and creating a new +// tensor that represents a slice of the original tensor at a given index. +// The cache key is a tuple of (original tensor, axis index to remove, slice +// index). +struct TensorSlicingCache { + // Type aliases + using Key = std::tuple; + + // Custom hash function for the tuple used as cache key + struct Hash { + size_t operator()(const Key& t) const { + auto [tv, idx, val] = t; + return std::hash{}(tv) ^ std::hash{}(idx) ^ + std::hash{}(val); + } + }; + + // Map type for storing cached sliced tensors + using Map = std::unordered_map; + + // Get the expr producing the indexed version of a tensor. If the expr already + // exists in the cache, returns the cached version. Otherwise, creates a new + // expr, producing a tensor "selected" on its dimension `stream_axis_index` at + // index `index`. Returns a pair of (expr, is_new) where is_new indicates + // whether the expr was newly created. 
+ std::pair get( + TensorView* tensor, + int64_t stream_axis_index, + Val* index) { + auto key = std::make_tuple(tensor, stream_axis_index, index); + auto it = cache_.find(key); + if (it != cache_.end()) { + return {it->second, false}; + } + + auto dom = tensor->getLogicalDomain(); + std::vector new_root; + new_root.reserve(dom.size() - 1); + + for (auto i : arange((int64_t)dom.size())) { + if (i != stream_axis_index) { + new_root.emplace_back(dom[i]->cloneWithoutRFactor()); + } + } + + auto td = IrBuilder::create( + new_root, TensorDomain::getContiguityFilledWith(new_root, true)); + auto out = IrBuilder::create(td, *tensor->getDataType()); + auto result = IrBuilder::create( + tensor, out, stream_axis_index, index); + + cache_[key] = result; + return {result, true}; + } + + private: + Map cache_; // Storage for cached sliced tensors +}; + +// Step 1: Group expressions into stream-parallel regions +std::vector groupStreamParallelRegions( + const std::vector& top_level_exprs, + const IdModel& id_model) { + std::vector new_top_level_exprs; + + for (auto* expr : top_level_exprs) { + // Skip expressions with no outputs + if (expr->outputs().size() == 0) { + new_top_level_exprs.push_back(expr); + continue; + } + + // Each expression should have exactly one output + NVF_CHECK( + expr->outputs().size() == 1, + "Each expr should have at most one output."); + + // Get the output tensor and check for stream parallelization + TensorView* output = expr->output(0)->as(); + IterDomain* stream_axis = getStreamAxis(output->getLoopDomain()); + + // If no stream axis found, keep the expression as is + if (stream_axis == nullptr) { + new_top_level_exprs.push_back(expr); + continue; + } + + // Verify that the expression can be handled as a standalone host operation + NVF_ERROR( + HostIrLower::isLowerableAsStandaloneHostOp(expr), + "Stream parallel type not supported for expr ", + expr); + + // Validate stream axis + validateStreamAxis(stream_axis, output); + + // Check if we can merge 
this expression with the previous for-loop + if (canMergeWithPreviousForLoop( + new_top_level_exprs, stream_axis, id_model)) { + // Merge with existing for-loop by adding the expression to its body + new_top_level_exprs.back()->as()->body().push_back(expr); + } else { + // Create a new for-loop for stream parallelization + auto* for_loop = IrBuilder::create( + stream_axis, + /*index=*/NamedScalar::getParallelIndex(ParallelType::Stream), + /*start=*/FusionGuard::getCurFusion()->zeroVal(), + /*stop=*/stream_axis->extent(), + /*step=*/FusionGuard::getCurFusion()->oneVal(), + /*vectorize=*/false, + /*vectorize_shift=*/nullptr, + /*unroll_required=*/false, + CircularBufferLoopStage::NotApplicable, + /*circular_buffer_loop_stage_depth=*/0); + // Add the expression to the new for-loop's body + for_loop->body().push_back(expr); + new_top_level_exprs.push_back(for_loop); + } + } + + return new_top_level_exprs; +} + +// Helper function to add allocations for tensors that need them +std::vector addTensorAllocations( + std::vector top_level_exprs, + const IdModel& id_model) { + std::vector new_top_level_exprs; + + for (auto* expr : top_level_exprs) { + if (expr->isA()) { + // add allocations for tensors produced in the loop that have a stream + // axes + auto* for_loop = expr->as(); + for (auto* body_expr : for_loop->body().exprs()) { + for (auto* output : + ir_utils::filterByType(body_expr->outputs())) { + if (findStreamAxisIndex(output, for_loop->iterDomain(), id_model) != + -1) { + new_top_level_exprs.push_back( + IrBuilder::create(output, MemoryType::Global)); + } + } + } + } + new_top_level_exprs.push_back(expr); + } + + return new_top_level_exprs; +} + +// Step 3: Process for-loop bodies by slicing tensors +std::vector processForLoopBodies( + std::vector top_level_exprs, + const IdModel& id_model) { + TensorSlicingCache tensor_slicing_cache; + + for (auto* expr : top_level_exprs) { + if (!expr->isA()) { + continue; + } + + auto* for_loop = expr->as(); + std::vector 
new_loop_body; + + // Lambda to process a tensor in a for-loop body + auto processTensor = [&](Expr*& expr, TensorView* tensor) { + if (auto stream_idx = + findStreamAxisIndex(tensor, for_loop->iterDomain(), id_model); + stream_idx != -1) { + auto [slicing, is_new] = + tensor_slicing_cache.get(tensor, stream_idx, for_loop->index()); + if (is_new) { + new_loop_body.push_back(slicing); + } + expr = ir_utils::replaceValInExprInputs(expr, tensor, slicing->out()); + if (expr->outputs().size() > 0 && expr->outputs()[0] == tensor) { + expr = + ir_utils::transferDefinitionToNewOutputs(expr, {slicing->out()}); + } + } + }; + + for (auto* body_expr : for_loop->body().exprs()) { + for (auto* input : + ir_utils::filterByType(body_expr->inputs())) { + processTensor(body_expr, input); + } + for (auto* output : + ir_utils::filterByType(body_expr->outputs())) { + processTensor(body_expr, output); + } + new_loop_body.push_back(body_expr); + } + + for_loop->body().clear(); + for (auto* expr : new_loop_body) { + for_loop->body().push_back(expr); + } + } + + return top_level_exprs; +} + +// Step 4: Add stream management and synchronization +std::vector addStreamManagement(std::vector top_level_exprs) { + // Process each top-level expression + for (auto* top_level_expr : top_level_exprs) { + // Skip non-for-loop expressions + if (!top_level_expr->isA()) { + continue; + } + + auto* for_loop = top_level_expr->as(); + std::vector new_loop_body; + + // Get the current stream before entering the loop + auto* get_current_stream = IrBuilder::create(); + hir::Stream* original_stream = get_current_stream->stream(); + new_loop_body.push_back(get_current_stream); + + // Set up a new stream for this iteration based on the loop index + auto* number_of_streams = + IrBuilder::create("numberOfStreams", DataType::Int); + auto* stream_index = mod(for_loop->index(), number_of_streams); + auto* stream = IrBuilder::create(stream_index); + auto* set_stream = IrBuilder::create(stream); + 
new_loop_body.push_back(set_stream); + + // Synchronize with the original stream before starting computation + auto* initial_sync_stream = + IrBuilder::create(original_stream); + new_loop_body.push_back(initial_sync_stream); + + // Add all the expressions to the loop body + for (auto* expr : for_loop->body().exprs()) { + new_loop_body.push_back(expr); + } + + // Restore the original stream and synchronize with the iteration's stream + auto* set_back_original_stream = + IrBuilder::create(original_stream); + new_loop_body.push_back(set_back_original_stream); + auto* sync_stream = IrBuilder::create(stream); + new_loop_body.push_back(sync_stream); + + // Update the for-loop body with the new expressions + for_loop->body().clear(); + for (auto* expr : new_loop_body) { + for_loop->body().push_back(expr); + } + } + + return top_level_exprs; +} + +} // anonymous namespace + +// StreamParallelType pass implementation. +// This pass handles stream parallelization of operations in a fusion. +// It works by: +// 1. Identifying stream-parallelized axes in tensor operations +// 2. Grouping compatible operations into stream-parallel for-loops +// 3. Setting up proper stream synchronization and management +// 4. Adding allocations for tensors that need them +// The pass ensures that: +// - Input tensors don't have stream axes +// - Only one stream axis exists per tensor +// - Stream axes are properly synchronized +// - Operations are correctly grouped into stream-parallel regions +// - The resulting HostIrContainer's top level expression is valid for execution +// and does not contain any stream axes +// +// TODO: Here, we assume that the fusion input is a HostIrContainer and use the +// linear structure of the HostIrContainer::topLevelExpr to greedily merge the +// adjacent compatible stream for-loop bodies. Ideally we should look at the dag +// and use the segmenter. 
+void StreamParallelType::runPass(Fusion* fusion) { + // Verify that input tensors don't have stream axes + NVF_CHECK( + std::all_of( + fusion->inputs().begin(), + fusion->inputs().end(), + [](Val* input) { + auto input_tv = dynamic_cast(input); + return input_tv == nullptr || + getStreamAxis(input_tv->getLoopDomain()) == nullptr; + }), + "Expected no stream axis in the TensorView inputs."); + + // Set up the fusion environment and build the ID model + FusionGuard fg(fusion); + hir::HostIrContainer* hic = dynamic_cast(fusion); + NVF_CHECK(hic, "Expected HostIrContainer"); + + IdModel id_model(fusion); + id_model.buildBroadcastGraph(); + + // Step 1: Group expressions into stream-parallel regions + std::vector top_level_exprs = + groupStreamParallelRegions(hic->topLevelExprs(), id_model); + + // Step 2: Add allocations for tensors that need them + top_level_exprs = addTensorAllocations(std::move(top_level_exprs), id_model); + + // Step 3: Process for-loop bodies by slicing tensors + top_level_exprs = processForLoopBodies(std::move(top_level_exprs), id_model); + + // Step 4: Add stream management and synchronization + top_level_exprs = addStreamManagement(std::move(top_level_exprs)); + + // Update the container's top-level expressions + hic->resetTopLevelExprs(top_level_exprs); +} + +} // namespace nvfuser::hir diff --git a/csrc/host_ir/pass/stream_parallel_type.h b/csrc/host_ir/pass/stream_parallel_type.h new file mode 100644 index 00000000000..8b5f138ad7e --- /dev/null +++ b/csrc/host_ir/pass/stream_parallel_type.h @@ -0,0 +1,36 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#pragma once + +#include +#include + +namespace nvfuser::hir { + +// A pass used in HostIrLower that takes a HostIrContainer as input, reads the +// TensorView's ParallelType::Stream, and modifies the HostIrContainer's top +// level expressions with the corresponding Host For Loops, whose bodies contain +// stream assignment, selecting on tensor's axis, and the exprs on those sliced +// tensors. After this pass, the ParallelType::Stream is removed from the +// TensorView's axis. +// +// An illustration of the pass can be found in the tests +// `test_host_ir_stream_lowering.cpp` +// with the option `NVFUSER_DUMP=host_ir`. +class StreamParallelType + : public preseg_passes::OptimizationPass { + friend class preseg_passes::OptimizationPass; + + protected: + static void runPass(Fusion* fusion); + static constexpr std::string_view name() { + return "StreamParallelType"; + } +}; + +} // namespace nvfuser::hir diff --git a/csrc/ir/internal_nodes.h b/csrc/ir/internal_nodes.h index a77545d63cd..066823e42c9 100644 --- a/csrc/ir/internal_nodes.h +++ b/csrc/ir/internal_nodes.h @@ -2533,6 +2533,10 @@ class ForLoop final : public Expr { return input(0); } + + IterDomain* iterDomain() const { + return input(1)->as(); + } + + Val* indexOrStartIfTrivial() const { return isTrivial() ?
start() : index(); } diff --git a/csrc/multidevice/executor.h b/csrc/multidevice/executor.h index c1cc3e31cfe..7dd08a87f0a 100644 --- a/csrc/multidevice/executor.h +++ b/csrc/multidevice/executor.h @@ -103,6 +103,10 @@ class MultiDeviceExecutor { return host_ir_executor_->getFusionExecutorCaches(); }; + auto* hostIrEvaluator() const { + return host_ir_executor_.get(); + } + private: // holds the Communicator to be used for execution Communicator& comm_; diff --git a/csrc/preseg_passes/optimization_pass.h b/csrc/preseg_passes/optimization_pass.h index 53d8a8acd3c..359a4a42742 100644 --- a/csrc/preseg_passes/optimization_pass.h +++ b/csrc/preseg_passes/optimization_pass.h @@ -18,8 +18,6 @@ namespace nvfuser::preseg_passes { -using FusionPass = std::function; - //! [experimental API] //! Base class to unify optimization pass APIs. //! OptimizationPass can be turned on/off programmatically with the `setEnabled` diff --git a/csrc/type.cpp b/csrc/type.cpp index c3664fc6805..31163c7ced7 100644 --- a/csrc/type.cpp +++ b/csrc/type.cpp @@ -729,7 +729,7 @@ static const char* parallel_type2string(ParallelType t) { case ParallelType::TIDx: return "threadIdx.x"; case ParallelType::Stream: - return "Stream"; + return "StreamIdx"; case ParallelType::Vectorize: return "V"; case ParallelType::Unroll: diff --git a/python/python_frontend/fusion_definition.cpp b/python/python_frontend/fusion_definition.cpp index c48abc9dbdc..b77947f1415 100644 --- a/python/python_frontend/fusion_definition.cpp +++ b/python/python_frontend/fusion_definition.cpp @@ -7,6 +7,7 @@ // clang-format on #include #include +#include #include #include #include @@ -452,6 +453,10 @@ std::pair> FusionDefinition:: if (scheds->multi_device_executor == nullptr) { MultiDeviceExecutorParams params; params.lower.communicator_backend = backend_type_; + // Disable StreamParallelType pass temporarily as proper stream lowering + // gets implemented + preseg_passes::OptimizationPassGuard guard( + false); 
scheds->multi_device_executor = std::make_unique( std::make_unique(*scheds->preschedFusion()), Communicator::getInstance(), diff --git a/tests/cpp/test_host_ir_stream_lowering.cpp b/tests/cpp/test_host_ir_stream_lowering.cpp new file mode 100644 index 00000000000..b77df002bc6 --- /dev/null +++ b/tests/cpp/test_host_ir_stream_lowering.cpp @@ -0,0 +1,814 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2025-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace nvfuser { + +namespace hir { + +using HirLowerStreamTest = NVFuserTest; + +TEST_F(HirLowerStreamTest, InputsAreNotStreamParallelized) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv = makeContigTensor(2); + hic->addInput(tv); + tv->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + preseg_passes::OptimizationPass::runPass(hic.get())); +} + +TEST_F(HirLowerStreamTest, Split) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + hic->addInput(tv0); + hic->addOutput(tv1); + hic->pushBackTopLevelExprs(tv1->definition()); + tv1->split(0, 2); + tv1->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + preseg_passes::OptimizationPass::runPass(hic.get())); +} + +TEST_F(HirLowerStreamTest, Merge) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + hic->addInput(tv0); + hic->addOutput(tv1); + hic->pushBackTopLevelExprs(tv1->definition()); + tv1->merge(0, 1); + tv1->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + preseg_passes::OptimizationPass::runPass(hic.get())); +} + +TEST_F(HirLowerStreamTest, SingleSetOp) { + auto 
hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + hic->addInput(tv0); + hic->addOutput(tv1); + hic->pushBackTopLevelExprs(tv1->definition()); + tv0->setMemoryType(MemoryType::Global); + tv1->setMemoryType(MemoryType::Global); + tv1->axis(0)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass::runPass(hic.get()); + + EXPECT_EQ(hic->topLevelExprs().size(), 2); + EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(1)->isA()); + + HostIrEvaluator hie(std::move(hic)); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8}, options); + auto output = hie.runWithInput({{tv0, input}})[0].as(); + + torch::cuda::synchronize(); + EXPECT_TRUE(output.equal(input)) + << "Output: " << output << " Expected: " << input; +} + +TEST_F(HirLowerStreamTest, SingleSetOpNonOutermost) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + hic->addInput(tv0); + hic->addOutput(tv1); + hic->pushBackTopLevelExprs(tv1->definition()); + tv0->setMemoryType(MemoryType::Global); + tv1->setMemoryType(MemoryType::Global); + tv1->axis(1)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass::runPass(hic.get()); + + EXPECT_EQ(hic->topLevelExprs().size(), 2); + EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(1)->isA()); + + HostIrEvaluator hie(std::move(hic)); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8}, options); + auto output = hie.runWithInput({{tv0, input}})[0].as(); + + torch::cuda::synchronize(); + EXPECT_TRUE(output.equal(input)) + << "Output: " << output << " Expected: " << input; +} + +TEST_F(HirLowerStreamTest, SingleBinaryOp) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv0 = makeContigTensor(2); + 
TensorView* tv1 = makeContigTensor(2); + TensorView* tv2 = add(tv0, tv1); + hic->addInput(tv0); + hic->addInput(tv1); + hic->addOutput(tv2); + hic->pushBackTopLevelExprs(tv2->definition()); + tv0->setMemoryType(MemoryType::Global); + tv1->setMemoryType(MemoryType::Global); + tv2->setMemoryType(MemoryType::Global); + tv2->axis(0)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass::runPass(hic.get()); + + EXPECT_EQ(hic->topLevelExprs().size(), 2); + EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(1)->isA()); + + HostIrEvaluator hie(std::move(hic)); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor tv0_input = at::rand({4, 4}, options); + at::Tensor tv1_input = at::rand({4, 4}, options); + // std::unordered_map inputs = {{tv0, input}}; + auto output = hie.runWithInput({{tv0, tv0_input}, {tv1, tv1_input}})[0] + .as(); + auto expected_output = tv0_input + tv1_input; + EXPECT_TRUE(output.equal(expected_output)) + << "Output: " << output << "Expected: " << expected_output; +} + +TEST_F(HirLowerStreamTest, TwoSetOps) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + TensorView* tv2 = set(tv1); + hic->addInput(tv0); + hic->addOutput(tv2); + hic->pushBackTopLevelExprs(tv1->definition()); + hic->pushBackTopLevelExprs(tv2->definition()); + tv0->setMemoryType(MemoryType::Global); + tv1->setMemoryType(MemoryType::Global); + tv2->setMemoryType(MemoryType::Global); + tv1->axis(0)->parallelize(ParallelType::Stream); + tv2->axis(0)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass::runPass(hic.get()); + + EXPECT_EQ(hic->topLevelExprs().size(), 3); + EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(1)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(2)->isA()); + + HostIrEvaluator hie(std::move(hic)); + + auto options = at::TensorOptions().device(at::kCUDA, 0); 
+ at::Tensor input = at::rand({4, 8}, options); + auto output = hie.runWithInput({{tv0, input}})[0].as(); + + torch::cuda::synchronize(); + EXPECT_TRUE(output.equal(input)) + << "Output: " << output << " Expected: " << input; +} + +TEST_F(HirLowerStreamTest, ThreeSetOpsWithDisjointsForLoops) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + TensorView* tv2 = set(tv1); + TensorView* tv3 = set(tv2); + hic->addInput(tv0); + hic->addOutput(tv3); + hic->pushBackTopLevelExprs(tv1->definition()); + hic->pushBackTopLevelExprs(tv2->definition()); + hic->pushBackTopLevelExprs(tv3->definition()); + tv0->setMemoryType(MemoryType::Global); + tv1->setMemoryType(MemoryType::Global); + tv2->setMemoryType(MemoryType::Global); + tv3->setMemoryType(MemoryType::Global); + tv1->axis(0)->parallelize(ParallelType::Stream); + tv3->axis(0)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass::runPass(hic.get()); + + EXPECT_EQ(hic->topLevelExprs().size(), 5); + EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(1)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(2)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(3)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(4)->isA()); + + HostIrEvaluator hie(std::move(hic)); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8}, options); + auto output = hie.runWithInput({{tv0, input}})[0].as(); + + torch::cuda::synchronize(); + EXPECT_TRUE(output.equal(input)) + << "Output: " << output << " Expected: " << input; +} + +TEST_F(HirLowerStreamTest, ReductionUnsupported) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = sum(tv0, {0}); + hic->addInput(tv0); + hic->addOutput(tv1); + hic->pushBackTopLevelExprs(tv1->definition()); + tv0->setMemoryType(MemoryType::Global); + 
tv1->setMemoryType(MemoryType::Global); + tv1->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + preseg_passes::OptimizationPass::runPass(hic.get())); +} + +TEST_F(HirLowerStreamTest, Reduction) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* tv0 = makeContigTensor(3); + TensorView* tv1 = sum(tv0, {2}); + hic->addInput(tv0); + hic->addOutput(tv1); + hic->pushBackTopLevelExprs(tv1->definition()); + tv0->setMemoryType(MemoryType::Global); + tv1->setMemoryType(MemoryType::Global); + tv1->axis(0)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass::runPass(hic.get()); + + EXPECT_EQ(hic->topLevelExprs().size(), 2); + EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(1)->isA()); + + HostIrEvaluator hie(std::move(hic)); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8, 2}, options); + auto output = hie.runWithInput({{tv0, input}})[0].as(); + + torch::cuda::synchronize(); + auto expected_output = input.sum(2); + EXPECT_TRUE(output.equal(expected_output)) + << "Output: " << output << " Expected: " << expected_output; +} + +TEST_F(HirLowerStreamTest, Matmul_M) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* a = makeContigTensor(2); + TensorView* b = makeContigTensor(2); + TensorView* c = matmul(a, b); + hic->addInput(a); + hic->addInput(b); + hic->addOutput(c); + hic->pushBackTopLevelExprs(c->definition()); + a->setMemoryType(MemoryType::Global); + b->setMemoryType(MemoryType::Global); + c->setMemoryType(MemoryType::Global); + c->axis(0)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass::runPass(hic.get()); + + EXPECT_EQ(hic->topLevelExprs().size(), 2); + EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(1)->isA()); + + HostIrEvaluator hie(std::move(hic)); + + constexpr int64_t M = 8, K = 4, N = 2; + auto options = 
at::TensorOptions().device(at::kCUDA, 0); + at::Tensor a_aten = at::rand({M, K}, options); + at::Tensor b_aten = at::rand({K, N}, options); + auto output = + hie.runWithInput({{a, a_aten}, {b, b_aten}})[0].as(); + + torch::cuda::synchronize(); + auto expected_output = at::matmul(a_aten, b_aten); + EXPECT_TRUE(torch::allclose(output, expected_output, 1e-2, 1e-2)) + << "Output: " << output << " Expected: " << expected_output; +} + +TEST_F(HirLowerStreamTest, BatchedMatmul) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* a = makeContigTensor(3); + TensorView* b = makeContigTensor(2); + TensorView* c = matmul(a, b); + hic->addInput(a); + hic->addInput(b); + hic->addOutput(c); + hic->pushBackTopLevelExprs(c->definition()); + a->setMemoryType(MemoryType::Global); + b->setMemoryType(MemoryType::Global); + c->setMemoryType(MemoryType::Global); + c->axis(0)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass::runPass(hic.get()); + + EXPECT_EQ(hic->topLevelExprs().size(), 2); + EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(1)->isA()); + + HostIrEvaluator hie(std::move(hic)); + + constexpr int64_t B = 16, M = 8, K = 4, N = 2; + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor a_aten = at::rand({B, M, K}, options); + at::Tensor b_aten = at::rand({K, N}, options); + auto output = + hie.runWithInput({{a, a_aten}, {b, b_aten}})[0].as(); + + torch::cuda::synchronize(); + auto expected_output = at::matmul(a_aten, b_aten); + EXPECT_TRUE(torch::allclose(output, expected_output, 1e-2, 1e-2)) + << "Output: " << output << " Expected: " << expected_output; +} + +TEST_F(HirLowerStreamTest, Matmul_N) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* a = makeContigTensor(2); + TensorView* b = makeContigTensor(2); + TensorView* c = matmul(a, b); + hic->addInput(a); + hic->addInput(b); + hic->addOutput(c); + hic->pushBackTopLevelExprs(c->definition()); + 
a->setMemoryType(MemoryType::Global); + b->setMemoryType(MemoryType::Global); + c->setMemoryType(MemoryType::Global); + c->axis(1)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass::runPass(hic.get()); + + EXPECT_EQ(hic->topLevelExprs().size(), 2); + EXPECT_TRUE(hic->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(hic->topLevelExprs().at(1)->isA()); + + HostIrEvaluator hie(std::move(hic)); + + constexpr int64_t M = 8, K = 4, N = 2; + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor a_aten = at::rand({M, K}, options); + at::Tensor b_aten = at::rand({K, N}, options); + auto output = + hie.runWithInput({{a, a_aten}, {b, b_aten}})[0].as(); + + torch::cuda::synchronize(); + auto expected_output = at::matmul(a_aten, b_aten); + EXPECT_TRUE(torch::allclose(output, expected_output, 1e-2, 1e-2)) + << "Output: " << output << " Expected: " << expected_output; +} + +TEST_F(HirLowerStreamTest, Matmul_K) { + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + TensorView* a = makeContigTensor(2); + TensorView* b = makeContigTensor(2); + TensorView* c = matmul(a, b); + hic->addInput(a); + hic->addInput(b); + hic->addOutput(c); + hic->pushBackTopLevelExprs(c->definition()); + a->setMemoryType(MemoryType::Global); + b->setMemoryType(MemoryType::Global); + c->setMemoryType(MemoryType::Global); + c->axis(-1)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + preseg_passes::OptimizationPass::runPass(hic.get())); +} + +// We don't support PostOnStream because it does not work well with pre-allocated +// outputs. 
There is no strong motivation to support PostOnStream +TEST_F(HirLowerStreamTest, DoNotSupportPostOnStream) { + const std::vector input_sizes = {4, 8, 32}; + const std::vector output_sizes = { + input_sizes.at(1), input_sizes.at(2)}; + + auto get_fusion = [input_sizes]() -> std::unique_ptr { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + auto tv0 = makeConcreteTensor(input_sizes); + auto tv1 = add(tv0, tv0); + auto tv2 = sum(tv1, {0}); + fusion->addInput(tv0); + fusion->addOutput(tv2); + return fusion; + }; + + auto hic = std::make_unique(); + FusionGuard fg(hic.get()); + + auto host_unit = IrBuilder::create(get_fusion()); + + IrCloner ir_cloner(hic.get()); + TensorView* input = + ir_cloner.clone(host_unit->fusion_to_execute()->inputs().at(0)) + ->as(); + TensorView* output = + ir_cloner.clone(host_unit->fusion_to_execute()->outputs().at(0)) + ->as(); + + std::vector inputs = {input}; + std::vector outputs = {output}; + auto post_on_stream = + IrBuilder::create(host_unit, inputs, outputs); + + hic->pushBackTopLevelExprs(post_on_stream); + + hic->addInput(input); + hic->addOutput(output); + + output->axis(-1)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + preseg_passes::OptimizationPass::runPass(hic.get())); +} + +} // namespace hir + +using MultiDeviceExecutorLowerStreamTest = NVFuserTest; + +TEST_F(MultiDeviceExecutorLowerStreamTest, InputsAreNotStreamParallelized) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv = makeContigTensor(2); + fusion->addInput(tv); + tv->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + MultiDeviceExecutor(std::move(fusion), Communicator::getInstance())); +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, Split) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + fusion->addInput(tv0); + fusion->addOutput(tv1); + tv1->split(0, 2); + 
tv1->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + MultiDeviceExecutor(std::move(fusion), Communicator::getInstance())); +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, Merge) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + fusion->addInput(tv0); + fusion->addOutput(tv1); + tv1->merge(0, 1); + tv1->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + MultiDeviceExecutor(std::move(fusion), Communicator::getInstance())); +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, SingleSetOp) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + fusion->addInput(tv0); + fusion->addOutput(tv1); + tv1->axis(0)->parallelize(ParallelType::Stream); + + MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance()); + + hir::HostIrContainer* container = executor.hostIrEvaluator()->container(); + EXPECT_EQ(container->topLevelExprs().size(), 2); + EXPECT_TRUE(container->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(1)->isA()); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8}, options); + auto output = + executor.runWithInput(KernelArgumentHolder({input}))[0].as(); + + torch::cuda::synchronize(); + EXPECT_TRUE(output.equal(input)) + << "Output: " << output << " Expected: " << input; +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, SingleSetOpNonOutermost) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + fusion->addInput(tv0); + fusion->addOutput(tv1); + tv1->axis(1)->parallelize(ParallelType::Stream); + + MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance()); + + hir::HostIrContainer* container = executor.hostIrEvaluator()->container(); + 
EXPECT_EQ(container->topLevelExprs().size(), 2); + EXPECT_TRUE(container->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(1)->isA()); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8}, options); + auto output = + executor.runWithInput(KernelArgumentHolder({input}))[0].as(); + + torch::cuda::synchronize(); + EXPECT_TRUE(output.equal(input)) + << "Output: " << output << " Expected: " << input; +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, SingleBinaryOp) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = makeContigTensor(2); + TensorView* tv2 = add(tv0, tv1); + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addOutput(tv2); + tv2->axis(0)->parallelize(ParallelType::Stream); + + MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance()); + + hir::HostIrContainer* container = executor.hostIrEvaluator()->container(); + EXPECT_EQ(container->topLevelExprs().size(), 2); + EXPECT_TRUE(container->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(1)->isA()); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + + at::Tensor tv0_input = at::rand({4, 4}, options); + at::Tensor tv1_input = at::rand({4, 4}, options); + auto output = + executor.runWithInput(KernelArgumentHolder({tv0_input, tv1_input}))[0] + .as(); + auto expected_output = tv0_input + tv1_input; + EXPECT_TRUE(output.equal(expected_output)) + << "Output: " << output << "Expected: " << expected_output; +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, TwoSetOps) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + TensorView* tv2 = set(tv1); + fusion->addInput(tv0); + fusion->addOutput(tv2); + tv1->axis(0)->parallelize(ParallelType::Stream); + tv2->axis(0)->parallelize(ParallelType::Stream); + + 
MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance()); + + hir::HostIrContainer* container = executor.hostIrEvaluator()->container(); + EXPECT_EQ(container->topLevelExprs().size(), 3); + EXPECT_TRUE(container->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(1)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(2)->isA()); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8}, options); + auto output = + executor.runWithInput(KernelArgumentHolder({input}))[0].as(); + + torch::cuda::synchronize(); + EXPECT_TRUE(output.equal(input)) + << "Output: " << output << " Expected: " << input; +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, ThreeSetOpsWithDisjointsForLoops) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = set(tv0); + TensorView* tv2 = set(tv1); + TensorView* tv3 = set(tv2); + fusion->addInput(tv0); + fusion->addOutput(tv3); + tv1->axis(0)->parallelize(ParallelType::Stream); + tv3->axis(0)->parallelize(ParallelType::Stream); + + MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance()); + + hir::HostIrContainer* container = executor.hostIrEvaluator()->container(); + EXPECT_EQ(container->topLevelExprs().size(), 5); + EXPECT_TRUE(container->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(1)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(2)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(3)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(4)->isA()); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8}, options); + auto output = + executor.runWithInput(KernelArgumentHolder({input}))[0].as(); + + torch::cuda::synchronize(); + EXPECT_TRUE(output.equal(input)) + << "Output: " << output << " Expected: " << input; +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, ReductionUnsupported) { + 
auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = sum(tv0, {0}); + fusion->addInput(tv0); + fusion->addOutput(tv1); + tv1->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + MultiDeviceExecutor(std::move(fusion), Communicator::getInstance())); +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, Reduction) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(3); + TensorView* tv1 = sum(tv0, {2}); + fusion->addInput(tv0); + fusion->addOutput(tv1); + tv1->axis(0)->parallelize(ParallelType::Stream); + + MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance()); + + hir::HostIrContainer* container = executor.hostIrEvaluator()->container(); + EXPECT_EQ(container->topLevelExprs().size(), 2); + EXPECT_TRUE(container->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(1)->isA()); + + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor input = at::rand({4, 8, 2}, options); + auto output = + executor.runWithInput(KernelArgumentHolder({input}))[0].as(); + + torch::cuda::synchronize(); + auto expected_output = input.sum(2); + EXPECT_TRUE(output.equal(expected_output)) + << "Output: " << output << " Expected: " << expected_output; +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, Matmul_M) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* a = makeContigTensor(2); + TensorView* b = makeContigTensor(2); + TensorView* c = matmul(a, b); + fusion->addInput(a); + fusion->addInput(b); + fusion->addOutput(c); + c->axis(0)->parallelize(ParallelType::Stream); + + MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance()); + + hir::HostIrContainer* container = executor.hostIrEvaluator()->container(); + EXPECT_EQ(container->topLevelExprs().size(), 2); + EXPECT_TRUE(container->topLevelExprs().at(0)->isA()); + 
EXPECT_TRUE(container->topLevelExprs().at(1)->isA()); + + constexpr int64_t M = 8, K = 4, N = 2; + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor a_aten = at::rand({M, K}, options); + at::Tensor b_aten = at::rand({K, N}, options); + auto output = executor.runWithInput(KernelArgumentHolder({a_aten, b_aten}))[0] + .as(); + + torch::cuda::synchronize(); + auto expected_output = at::matmul(a_aten, b_aten); + EXPECT_TRUE(torch::allclose(output, expected_output, 1e-2, 1e-2)) + << "Output: " << output << " Expected: " << expected_output; +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, BatchedMatmul) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* a = makeContigTensor(3); + TensorView* b = makeContigTensor(2); + TensorView* c = matmul(a, b); + fusion->addInput(a); + fusion->addInput(b); + fusion->addOutput(c); + c->axis(0)->parallelize(ParallelType::Stream); + + MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance()); + + hir::HostIrContainer* container = executor.hostIrEvaluator()->container(); + EXPECT_EQ(container->topLevelExprs().size(), 2); + EXPECT_TRUE(container->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(1)->isA()); + + constexpr int64_t B = 16, M = 8, K = 4, N = 2; + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor a_aten = at::rand({B, M, K}, options); + at::Tensor b_aten = at::rand({K, N}, options); + auto output = executor.runWithInput(KernelArgumentHolder({a_aten, b_aten}))[0] + .as(); + + torch::cuda::synchronize(); + auto expected_output = at::matmul(a_aten, b_aten); + EXPECT_TRUE(torch::allclose(output, expected_output, 1e-2, 1e-2)) + << "Output: " << output << " Expected: " << expected_output; +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, Matmul_N) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* a = makeContigTensor(2); + TensorView* b = makeContigTensor(2); + TensorView* c = matmul(a, 
b); + fusion->addInput(a); + fusion->addInput(b); + fusion->addOutput(c); + c->axis(1)->parallelize(ParallelType::Stream); + + MultiDeviceExecutor executor(std::move(fusion), Communicator::getInstance()); + + hir::HostIrContainer* container = executor.hostIrEvaluator()->container(); + EXPECT_EQ(container->topLevelExprs().size(), 2); + EXPECT_TRUE(container->topLevelExprs().at(0)->isA()); + EXPECT_TRUE(container->topLevelExprs().at(1)->isA()); + + constexpr int64_t M = 8, K = 4, N = 2; + auto options = at::TensorOptions().device(at::kCUDA, 0); + at::Tensor a_aten = at::rand({M, K}, options); + at::Tensor b_aten = at::rand({K, N}, options); + auto output = executor.runWithInput(KernelArgumentHolder({a_aten, b_aten}))[0] + .as(); + + torch::cuda::synchronize(); + auto expected_output = at::matmul(a_aten, b_aten); + EXPECT_TRUE(torch::allclose(output, expected_output, 1e-2, 1e-2)) + << "Output: " << output << " Expected: " << expected_output; +} + +TEST_F(MultiDeviceExecutorLowerStreamTest, Matmul_K) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* a = makeContigTensor(2); + TensorView* b = makeContigTensor(2); + TensorView* c = matmul(a, b); + fusion->addInput(a); + fusion->addInput(b); + fusion->addOutput(c); + c->axis(-1)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + MultiDeviceExecutor(std::move(fusion), Communicator::getInstance())); +} + +// We only support Stream parallel type on ops that support pre-allocated +// output, which means they need a special handle in HostIrEvaluator and they +// need to be lowered as a Host Ir Op in the TopLevelExpression, not a +// PostOnStream(HostUnit(.)). See HostIrLower::isLoweredAsStandaloneHostOp and +// the test HirLowerStreamTest.DoNotSupportPostOnStream +TEST_F(MultiDeviceExecutorLowerStreamTest, DoNotSupportPostOnStream) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = + abs(tv0); // arbitrary 
example of an unsupported op. There is no deep + // reason why we don't support it -- if needed we could widen the + // support. But I want to make sure that an unsupported op does not + // silently fail + fusion->addInput(tv0); + fusion->addOutput(tv1); + tv1->axis(0)->parallelize(ParallelType::Stream); + + EXPECT_ANY_THROW( + MultiDeviceExecutor(std::move(fusion), Communicator::getInstance())); +} + +} // namespace nvfuser diff --git a/tests/cpp/test_multidevice_host_ir.cpp b/tests/cpp/test_multidevice_host_ir.cpp index 88286d6e4c0..db53f7f114d 100644 --- a/tests/cpp/test_multidevice_host_ir.cpp +++ b/tests/cpp/test_multidevice_host_ir.cpp @@ -9,6 +9,7 @@ #include #include #include +#include #include #include #include @@ -362,6 +363,10 @@ TEST_F(P2PCommHostIrTest, CoalescedRingPairwiseExchange) { using OverlapDistributedMatmulTest = MultiDeviceTest; TEST_F(OverlapDistributedMatmulTest, AG_matmul) { + // Disable StreamParallelType pass temporarily as proper stream lowering gets + // implemented + preseg_passes::OptimizationPassGuard guard(false); + constexpr int64_t M = 32768; constexpr int64_t K = 32768; constexpr int64_t N = 1024; @@ -417,6 +422,9 @@ TEST_F(OverlapDistributedMatmulTest, AG_matmul) { } TEST_F(OverlapDistributedMatmulTest, AG_linear) { + // Disable StreamParallelType pass temporarily + preseg_passes::OptimizationPassGuard guard(false); + constexpr int64_t M = 32768; constexpr int64_t K = 32768; constexpr int64_t N = 1024;