From 8997612f01ea753914f0de607c19cc26922f9f7a Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Thu, 9 Oct 2025 12:07:33 -0700 Subject: [PATCH 1/6] wip --- csrc/multidevice/utils.cpp | 12 ++- csrc/multidevice/utils.h | 14 ++-- csrc/preseg_passes/propagate_shardings.cpp | 94 ++++++++-------------- tests/cpp/test_stream.cpp | 45 +++++++++++ 4 files changed, 97 insertions(+), 68 deletions(-) diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index 000821097ea..ad06d0644a5 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -121,6 +121,8 @@ std::unordered_map mapIterDomainToTensorAxis( return id_to_axis; } +} // namespace + // Finds the logical IterDomain that transitively produces `id` and returns its // tensor axis. Returns -1 for reduction dimensions because they don't // correspond to any tensor axis. @@ -193,8 +195,6 @@ int64_t getProducingLogicalAxis(const TensorView* tv, IterDomain* id) { } } -} // namespace - int64_t getShardedLogicalAxis( const TensorView* tv, const ParallelType parallel_type) { @@ -208,7 +208,8 @@ int64_t getShardedLogicalAxis( IterDomain* getShardedIterDomain( const TensorView* tv, - const ParallelType parallel_type) { + const ParallelType parallel_type, + const std::vector& domain) { // The allocation domain for multidevice TensorViews is set during // presegmentation, which is after concretization. This exposes a issue: // allocation domain is not set for fusion inputs before presegmentation and @@ -218,7 +219,10 @@ IterDomain* getShardedIterDomain( // same DID parallelization. For ParalleType::Stream, fusion inputs will // always be fully allocated, and segment inputs/outputs may be partially / // fully allocated which can be inferred from its allocation domain. 
- const std::vector& domain = [&]() { + const std::vector& selected_domain = [&]() { + if (!domain.empty()) { + return domain; + } if (parallel_type == ParallelType::Stream) { return tv->getMaybeAllocationDomain(); } diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index f609fec38fc..b98b95bd26d 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -109,13 +109,17 @@ void unshard(TensorView*); // extent if that IterDomain is sharded. int64_t getShardedLogicalAxis(const TensorView* tv, ParallelType parallel_type); -// Returns the IterDomain that's parallelized on `parallel_type`. If it's not -// found, returns nullptr. `parallel_type` decides which domain to look at. -// ParallelType::Stream looks at the allocation domain and DIDs look at the loop -// domain. Refer to the implementation for the reason. +int64_t getProducingLogicalAxis(const TensorView* tv, IterDomain* id); + +// Returns the IterDomain that's parallelized on `parallel_type` in the given +// domain. If it's not found, returns nullptr. If no domain is given, +// `parallel_type` decides which domain to look at. ParallelType::Stream looks +// at the allocation domain and DIDs look at the loop domain. Refer to the +// implementation for the reason. IterDomain* getShardedIterDomain( const TensorView* tv, - ParallelType parallel_type); + ParallelType parallel_type, + const std::vector& domain = {}); // Shards the input tensor along `axis`. How the tensor gets sliced along `axis` // is determined by `mesh` and `device_id`. Returns the sharded tensor. 
diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index de6fd81a3f2..e40ce476097 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -41,61 +41,57 @@ bool isSplitDivisible(IterDomain* id, Split* ref_split) { return id_extent % split_factor == 0; } +// Sort the given tvs by the number of device/stream dimensions in descending +// order. Break ties by rank of device mesh. template -std::vector filterTvsWithMesh(const Range& tvs) { - std::vector tvs_with_mesh; - std::copy_if( - tvs.begin(), - tvs.end(), - std::back_inserter(tvs_with_mesh), - [](TensorView* tv) { return tv != nullptr && tv->hasDeviceMesh(); }); - return tvs_with_mesh; -} +std::vector sortTvsByParallelDims(const Range& tvs) { + auto num_parallel_dims = [](TensorView* tv) { + return std::count_if( + tv->getLoopDomain().begin(), + tv->getLoopDomain().end(), + [](IterDomain* id) { + return !id->isReduction() && (id->isStream() || id->isDeviceDim()); + }); + }; -// Sort the given tvs by the number of device dimensions in descending order. -// Break ties by the total number of dimensions. -// Only includes TensorViews that have a device mesh. -template -std::vector sortTvsByDeviceDims(const Range& tvs) { - // Filter out TVs without a device mesh - std::vector tvs_with_mesh = filterTvsWithMesh(tvs); - - // Then sort the filtered TVs - std::stable_sort( - tvs_with_mesh.begin(), tvs_with_mesh.end(), [](auto a, auto b) { - int64_t a_device_dims = numDeviceDims(a); - int64_t b_device_dims = numDeviceDims(b); - if (a_device_dims != b_device_dims) { - return a_device_dims > b_device_dims; - } - // Break ties by rank of device mesh. 
- return a->getDeviceMesh().rank() > b->getDeviceMesh().rank(); - }); + std::vector tvs_vec(tvs.begin(), tvs.end()); + + std::ranges::stable_sort(tvs_vec, [&num_parallel_dims](auto a, auto b) { + int64_t a_device_dims = num_parallel_dims(a); + int64_t b_device_dims = num_parallel_dims(b); + if (a_device_dims != b_device_dims) { + return a_device_dims > b_device_dims; + } + // Break ties by rank of device mesh. + return a->getDeviceMesh().rank() > b->getDeviceMesh().rank(); + }); - return tvs_with_mesh; + return tvs_vec; } // Order the inputs of the expression based on their priority. // For linear op, we use weights and bias before input. // For matmul op, we use weights before input. -// For other ops, we sort the inputs by the number of device dimensions in -// descending order. +// For other ops, we sort the inputs by the number of device/stream dimensions +// in descending order. std::vector getOrderedReferenceInputs(Expr* expr) { const auto& inputs = ir_utils::filterByType(expr->inputs()); if (LinearOp* linear_op = dynamic_cast(expr)) { // Use weights and bias before input. - return filterTvsWithMesh(std::vector( - {linear_op->inB(), linear_op->bias(), linear_op->inA()})); + if (linear_op->hasBias()) { + return {linear_op->inB(), linear_op->bias(), linear_op->inA()}; + } else { + return {linear_op->inB(), linear_op->inA()}; + } } if (MatmulOp* matmul_op = dynamic_cast(expr)) { // Use weights before input. - return filterTvsWithMesh( - std::vector({matmul_op->inB(), matmul_op->inA()})); + return {matmul_op->inB(), matmul_op->inA()}; } - // Sort inputs by number of device dimensions in descending order - std::vector sorted_inputs = sortTvsByDeviceDims(inputs); + // Sort inputs by number of device/stream dimensions in descending order + std::vector sorted_inputs = sortTvsByParallelDims(inputs); return sorted_inputs; } @@ -416,11 +412,6 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // Propagate shardings from reference inputs in order. 
for (auto* ref_input : reference_inputs) { NVF_ERROR(ref_input != nullptr); - NVF_ERROR( - ref_input->hasDeviceMesh(), - "Reference input ", - ref_input, - " has no device mesh."); // Consider out [M, N] = linear (inp [M, K], weight (N, // K)) with inp sharded on M ([DIDx(d), M/d, K]) and weight sharded on N @@ -455,24 +446,9 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // All outputs of an expression (Welford, SDPA) should be uniformly sharded. // We pick the most parallel output as the reference. // This is to avoid picking seed/offset tvs in SDPA. - std::vector sorted_outputs = sortTvsByDeviceDims(outputs); - - if (sorted_outputs.empty()) { - // No output with a device mesh. - continue; - } - + std::vector sorted_outputs = sortTvsByParallelDims(outputs); TensorView* ref_output = sorted_outputs.front(); - NVF_ERROR( - ref_output != nullptr && ref_output->hasDeviceMesh(), - "Reference output ", - ref_output, - " has no device mesh."); - - // For fusion inputs, only check if they have a device mesh. We do not - // modify their sharding. For non-fusion inputs, we try to propagate - // shardings from the reference output for parallel types that are not - // already present. + for (auto* target : ir_utils::filterByType(expr->inputs())) { // Allow inputs to be stream parallelized for easier analysis. 
if (user_sharded_tvs.count(target) && !target->isFusionInput()) { diff --git a/tests/cpp/test_stream.cpp b/tests/cpp/test_stream.cpp index c28398a8eb4..44b3e6dc6e8 100644 --- a/tests/cpp/test_stream.cpp +++ b/tests/cpp/test_stream.cpp @@ -16,6 +16,7 @@ #include #include #include +#include #include namespace nvfuser { @@ -86,4 +87,48 @@ TEST_F(StreamTest, haveDifferentShardings) { EXPECT_TRUE(haveDifferentShardings(tv2, tv3, {ParallelType::Stream})); } +TEST_F(StreamTest, ForwardPropagation) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + constexpr int64_t s = 2; + + TensorView* in = makeContigTensor(2); + TensorView* w = makeContigTensor(2); + TensorView* out = matmul(in, w); + fusion->addInput(in); + fusion->addInput(w); + fusion->addOutput(out); + + w->outer_split(1, s); + w->axis(1)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass< + preseg_passes::PropagateShardingsPass>::runPass(fusion.get()); + EXPECT_THAT(out->axis(1), IsParallelized(ParallelType::Stream)); +} + +TEST_F(StreamTest, BackwardPropagation) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + constexpr int64_t s = 2; + + TensorView* tv0 = makeContigTensor(2); + TensorView* tv1 = makeContigTensor(2); + TensorView* tv2 = add(tv0, IrBuilder::create(1.0)); + TensorView* tv3 = add(tv1, IrBuilder::create(1.0)); + TensorView* tv4 = add(tv2, tv3); + fusion->addInput(tv0); + fusion->addInput(tv1); + fusion->addOutput(tv4); + + w->outer_split(1, s); + w->axis(1)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass< + preseg_passes::PropagateShardingsPass>::runPass(fusion.get()); + EXPECT_THAT(out->axis(1), IsParallelized(ParallelType::Stream)); +} + } // namespace nvfuser From 00b829a033d22d5b7e8c8e609a504fe45dce23a1 Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Thu, 9 Oct 2025 13:11:00 -0700 Subject: [PATCH 2/6] tests --- csrc/multidevice/utils.cpp | 12 ++++------ csrc/multidevice/utils.h | 
14 ++++------- csrc/preseg_passes/propagate_shardings.cpp | 3 +++ tests/cpp/test_stream.cpp | 27 +++++++++++----------- 4 files changed, 25 insertions(+), 31 deletions(-) diff --git a/csrc/multidevice/utils.cpp b/csrc/multidevice/utils.cpp index ad06d0644a5..000821097ea 100644 --- a/csrc/multidevice/utils.cpp +++ b/csrc/multidevice/utils.cpp @@ -121,8 +121,6 @@ std::unordered_map mapIterDomainToTensorAxis( return id_to_axis; } -} // namespace - // Finds the logical IterDomain that transitively produces `id` and returns its // tensor axis. Returns -1 for reduction dimensions because they don't // correspond to any tensor axis. @@ -195,6 +193,8 @@ int64_t getProducingLogicalAxis(const TensorView* tv, IterDomain* id) { } } +} // namespace + int64_t getShardedLogicalAxis( const TensorView* tv, const ParallelType parallel_type) { @@ -208,8 +208,7 @@ int64_t getShardedLogicalAxis( IterDomain* getShardedIterDomain( const TensorView* tv, - const ParallelType parallel_type, - const std::vector& domain) { + const ParallelType parallel_type) { // The allocation domain for multidevice TensorViews is set during // presegmentation, which is after concretization. This exposes a issue: // allocation domain is not set for fusion inputs before presegmentation and @@ -219,10 +218,7 @@ IterDomain* getShardedIterDomain( // same DID parallelization. For ParalleType::Stream, fusion inputs will // always be fully allocated, and segment inputs/outputs may be partially / // fully allocated which can be inferred from its allocation domain. 
- const std::vector& selected_domain = [&]() { - if (!domain.empty()) { - return domain; - } + const std::vector& domain = [&]() { if (parallel_type == ParallelType::Stream) { return tv->getMaybeAllocationDomain(); } diff --git a/csrc/multidevice/utils.h b/csrc/multidevice/utils.h index b98b95bd26d..f609fec38fc 100644 --- a/csrc/multidevice/utils.h +++ b/csrc/multidevice/utils.h @@ -109,17 +109,13 @@ void unshard(TensorView*); // extent if that IterDomain is sharded. int64_t getShardedLogicalAxis(const TensorView* tv, ParallelType parallel_type); -int64_t getProducingLogicalAxis(const TensorView* tv, IterDomain* id); - -// Returns the IterDomain that's parallelized on `parallel_type` in the given -// domain. If it's not found, returns nullptr. If no domain is given, -// `parallel_type` decides which domain to look at. ParallelType::Stream looks -// at the allocation domain and DIDs look at the loop domain. Refer to the -// implementation for the reason. +// Returns the IterDomain that's parallelized on `parallel_type`. If it's not +// found, returns nullptr. `parallel_type` decides which domain to look at. +// ParallelType::Stream looks at the allocation domain and DIDs look at the loop +// domain. Refer to the implementation for the reason. IterDomain* getShardedIterDomain( const TensorView* tv, - ParallelType parallel_type, - const std::vector& domain = {}); + ParallelType parallel_type); // Shards the input tensor along `axis`. How the tensor gets sliced along `axis` // is determined by `mesh` and `device_id`. Returns the sharded tensor. diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index e40ce476097..e8418bc9631 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -443,6 +443,9 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // inputs. See MultiDevicePresegPassesTest.ResidualAdd for an example. 
for (Expr* expr : exprs | std::views::reverse) { const auto& outputs = ir_utils::filterByType(expr->outputs()); + if (outputs.empty()) { + continue; + } // All outputs of an expression (Welford, SDPA) should be uniformly sharded. // We pick the most parallel output as the reference. // This is to avoid picking seed/offset tvs in SDPA. diff --git a/tests/cpp/test_stream.cpp b/tests/cpp/test_stream.cpp index 44b3e6dc6e8..72ed565d993 100644 --- a/tests/cpp/test_stream.cpp +++ b/tests/cpp/test_stream.cpp @@ -15,8 +15,8 @@ #include #include #include -#include -#include +#include +#include #include namespace nvfuser { @@ -91,7 +91,7 @@ TEST_F(StreamTest, ForwardPropagation) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); - constexpr int64_t s = 2; + const int64_t s = 2; TensorView* in = makeContigTensor(2); TensorView* w = makeContigTensor(2); @@ -105,30 +105,29 @@ TEST_F(StreamTest, ForwardPropagation) { preseg_passes::OptimizationPass< preseg_passes::PropagateShardingsPass>::runPass(fusion.get()); - EXPECT_THAT(out->axis(1), IsParallelized(ParallelType::Stream)); + EXPECT_TRUE(out->axis(1)->isStream()) << out; } TEST_F(StreamTest, BackwardPropagation) { auto fusion = std::make_unique(); FusionGuard fg(fusion.get()); - constexpr int64_t s = 2; + const int64_t s = 2; TensorView* tv0 = makeContigTensor(2); - TensorView* tv1 = makeContigTensor(2); - TensorView* tv2 = add(tv0, IrBuilder::create(1.0)); - TensorView* tv3 = add(tv1, IrBuilder::create(1.0)); - TensorView* tv4 = add(tv2, tv3); + TensorView* tv1 = add(tv0, IrBuilder::create(1.0)); + TensorView* tv2 = add(tv1, tv1); fusion->addInput(tv0); - fusion->addInput(tv1); - fusion->addOutput(tv4); + fusion->addOutput(tv2); - w->outer_split(1, s); - w->axis(1)->parallelize(ParallelType::Stream); + tv2->outer_split(0, s); + tv2->axis(0)->parallelize(ParallelType::Stream); preseg_passes::OptimizationPass< preseg_passes::PropagateShardingsPass>::runPass(fusion.get()); - EXPECT_THAT(out->axis(1), 
IsParallelized(ParallelType::Stream)); + for (auto* tv : {tv0, tv1, tv2}) { + EXPECT_TRUE(tv->axis(0)->isStream()) << tv; + } } } // namespace nvfuser From 67fe3107c4de4d27fc724be49cb57a5ba2b43d95 Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Thu, 9 Oct 2025 16:06:11 -0700 Subject: [PATCH 3/6] skip scatter --- csrc/preseg_passes/propagate_shardings.cpp | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index e8418bc9631..8c9e7a890ec 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -182,6 +182,17 @@ void transformLoopDomain( target->setDeviceMesh(ref->getDeviceMesh()); } + bool is_scatter_op = direction == PropagateDirection::kForward + ? target->definition()->isA() + : ref->definition()->isA(); + + if (is_scatter_op) { + // Scatter op output has a disjoint logical-to-loop domain. + // So we skip propagation. It is not clear to me device / stream + // parallelization would mean on scatter output. 
+ return; + } + std::unordered_map ref2target = getRef2TargetMap(ref, target, direction); From 1024527b46536f17748e69f1acb3b0196a0fc3a1 Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Thu, 9 Oct 2025 16:45:11 -0700 Subject: [PATCH 4/6] itertype bug --- csrc/ops/utils.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/csrc/ops/utils.cpp b/csrc/ops/utils.cpp index 939cc9234db..0f5fa995b65 100644 --- a/csrc/ops/utils.cpp +++ b/csrc/ops/utils.cpp @@ -385,6 +385,8 @@ IterDomain* newOutputIterDomain( extent_val = promoteSize(extent_val, id->extent()); if (iter_type.has_value()) { iter_type = promoteIterType(iter_type.value(), id->getIterType()); + } else if (id->isGatherScatter()) { + iter_type = IterType::Iteration; } else { iter_type = id->getIterType(); } From bdd09b8a57a3080ae12da57cffc830df2063aee0 Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Thu, 9 Oct 2025 17:55:55 -0700 Subject: [PATCH 5/6] fix condition --- csrc/preseg_passes/propagate_shardings.cpp | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index 8c9e7a890ec..b5353d73a64 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -182,13 +182,12 @@ void transformLoopDomain( target->setDeviceMesh(ref->getDeviceMesh()); } - bool is_scatter_op = direction == PropagateDirection::kForward - ? target->definition()->isA() - : ref->definition()->isA(); - - if (is_scatter_op) { + // If the target is a scatter op output, skip propagation. + if (target->definition() != nullptr && + target->definition()->isA()) { // Scatter op output has a disjoint logical-to-loop domain. - // So we skip propagation. It is not clear to me device / stream + // So we skip propagation to avoid errors in the following code such as when + // setting the loop domain.
It is not clear to me what device / stream // parallelization would mean on scatter output. return; } From 4a92c754c0bd3accfadc9e9e25bce976e26f83b3 Mon Sep 17 00:00:00 2001 From: Priya Mishra <26priya11@gmail.com> Date: Mon, 13 Oct 2025 13:00:02 -0700 Subject: [PATCH 6/6] review comments --- csrc/preseg_passes/propagate_shardings.cpp | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index b5353d73a64..463749b779d 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -57,13 +57,8 @@ std::vector sortTvsByParallelDims(const Range& tvs) { std::vector tvs_vec(tvs.begin(), tvs.end()); std::ranges::stable_sort(tvs_vec, [&num_parallel_dims](auto a, auto b) { - int64_t a_device_dims = num_parallel_dims(a); - int64_t b_device_dims = num_parallel_dims(b); - if (a_device_dims != b_device_dims) { - return a_device_dims > b_device_dims; - } - // Break ties by rank of device mesh. - return a->getDeviceMesh().rank() > b->getDeviceMesh().rank(); + return std::make_pair(num_parallel_dims(a), a->getDeviceMesh().rank()) > + std::make_pair(num_parallel_dims(b), b->getDeviceMesh().rank()); }); return tvs_vec; @@ -76,7 +71,7 @@ std::vector sortTvsByParallelDims(const Range& tvs) { // in descending order. std::vector getOrderedReferenceInputs(Expr* expr) { const auto& inputs = ir_utils::filterByType(expr->inputs()); - if (LinearOp* linear_op = dynamic_cast(expr)) { + if (auto* linear_op = dynamic_cast(expr)) { // Use weights and bias before input. if (linear_op->hasBias()) { return {linear_op->inB(), linear_op->bias(), linear_op->inA()}; @@ -85,15 +80,13 @@ std::vector getOrderedReferenceInputs(Expr* expr) { } } - if (MatmulOp* matmul_op = dynamic_cast(expr)) { + if (auto* matmul_op = dynamic_cast(expr)) { // Use weights before input. 
return {matmul_op->inB(), matmul_op->inA()}; } // Sort inputs by number of device/stream dimensions in descending order - std::vector sorted_inputs = sortTvsByParallelDims(inputs); - - return sorted_inputs; + return sortTvsByParallelDims(inputs); } // Returns the set of parallel types not seen on the loop domain of the given