diff --git a/csrc/preseg_passes/propagate_shardings.cpp b/csrc/preseg_passes/propagate_shardings.cpp index de6fd81a3f2..463749b779d 100644 --- a/csrc/preseg_passes/propagate_shardings.cpp +++ b/csrc/preseg_passes/propagate_shardings.cpp @@ -41,63 +41,52 @@ bool isSplitDivisible(IterDomain* id, Split* ref_split) { return id_extent % split_factor == 0; } +// Sort the given tvs by the number of device/stream dimensions in descending +// order. Break ties by rank of device mesh. template -std::vector filterTvsWithMesh(const Range& tvs) { - std::vector tvs_with_mesh; - std::copy_if( - tvs.begin(), - tvs.end(), - std::back_inserter(tvs_with_mesh), - [](TensorView* tv) { return tv != nullptr && tv->hasDeviceMesh(); }); - return tvs_with_mesh; -} +std::vector sortTvsByParallelDims(const Range& tvs) { + auto num_parallel_dims = [](TensorView* tv) { + return std::count_if( + tv->getLoopDomain().begin(), + tv->getLoopDomain().end(), + [](IterDomain* id) { + return !id->isReduction() && (id->isStream() || id->isDeviceDim()); + }); + }; -// Sort the given tvs by the number of device dimensions in descending order. -// Break ties by the total number of dimensions. -// Only includes TensorViews that have a device mesh. -template -std::vector sortTvsByDeviceDims(const Range& tvs) { - // Filter out TVs without a device mesh - std::vector tvs_with_mesh = filterTvsWithMesh(tvs); - - // Then sort the filtered TVs - std::stable_sort( - tvs_with_mesh.begin(), tvs_with_mesh.end(), [](auto a, auto b) { - int64_t a_device_dims = numDeviceDims(a); - int64_t b_device_dims = numDeviceDims(b); - if (a_device_dims != b_device_dims) { - return a_device_dims > b_device_dims; - } - // Break ties by rank of device mesh. 
- return a->getDeviceMesh().rank() > b->getDeviceMesh().rank(); - }); + std::vector tvs_vec(tvs.begin(), tvs.end()); - return tvs_with_mesh; + std::ranges::stable_sort(tvs_vec, [&num_parallel_dims](auto a, auto b) { + return std::make_pair(num_parallel_dims(a), a->getDeviceMesh().rank()) > + std::make_pair(num_parallel_dims(b), b->getDeviceMesh().rank()); + }); + + return tvs_vec; } // Order the inputs of the expression based on their priority. // For linear op, we use weights and bias before input. // For matmul op, we use weights before input. -// For other ops, we sort the inputs by the number of device dimensions in -// descending order. +// For other ops, we sort the inputs by the number of device/stream dimensions +// in descending order. std::vector getOrderedReferenceInputs(Expr* expr) { const auto& inputs = ir_utils::filterByType(expr->inputs()); - if (LinearOp* linear_op = dynamic_cast(expr)) { + if (auto* linear_op = dynamic_cast(expr)) { // Use weights and bias before input. - return filterTvsWithMesh(std::vector( - {linear_op->inB(), linear_op->bias(), linear_op->inA()})); + if (linear_op->hasBias()) { + return {linear_op->inB(), linear_op->bias(), linear_op->inA()}; + } else { + return {linear_op->inB(), linear_op->inA()}; + } } - if (MatmulOp* matmul_op = dynamic_cast(expr)) { + if (auto* matmul_op = dynamic_cast(expr)) { // Use weights before input. 
- return filterTvsWithMesh( - std::vector({matmul_op->inB(), matmul_op->inA()})); + return {matmul_op->inB(), matmul_op->inA()}; } - // Sort inputs by number of device dimensions in descending order - std::vector sorted_inputs = sortTvsByDeviceDims(inputs); - - return sorted_inputs; + // Sort inputs by number of device/stream dimensions in descending order + return sortTvsByParallelDims(inputs); } // Returns the set of parallel types not seen on the loop domain of the given @@ -186,6 +175,16 @@ void transformLoopDomain( target->setDeviceMesh(ref->getDeviceMesh()); } + // If the target is a scatter op output, skip propagation + if (target->definition() != nullptr && + target->definition()->isA()) { + // Scatter op output has a disjoint logical-to-loop domain. + // So we skip propagation to avoid errors in the following code such as when + // setting the loop domain. It is not clear to me what device / stream + // parallelization would mean on scatter output. + return; + } + std::unordered_map ref2target = getRef2TargetMap(ref, target, direction); @@ -416,11 +415,6 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // Propagate shardings from reference inputs in order. for (auto* ref_input : reference_inputs) { NVF_ERROR(ref_input != nullptr); - NVF_ERROR( - ref_input->hasDeviceMesh(), - "Reference input ", - ref_input, - " has no device mesh."); // Consider out [M, N] = linear (inp [M, K], weight (N, // K)) with inp sharded on M ([DIDx(d), M/d, K]) and weight sharded on N @@ -452,27 +446,15 @@ void PropagateShardingsPass::runPass(Fusion* fusion) { // inputs. See MultiDevicePresegPassesTest.ResidualAdd for an example. for (Expr* expr : exprs | std::views::reverse) { const auto& outputs = ir_utils::filterByType(expr->outputs()); + if (outputs.empty()) { + continue; + } // All outputs of an expression (Welford, SDPA) should be uniformly sharded. // We pick the most parallel output as the reference. 
// This is to avoid picking seed/offset tvs in SDPA. - std::vector sorted_outputs = sortTvsByDeviceDims(outputs); - - if (sorted_outputs.empty()) { - // No output with a device mesh. - continue; - } - + std::vector sorted_outputs = sortTvsByParallelDims(outputs); TensorView* ref_output = sorted_outputs.front(); - NVF_ERROR( - ref_output != nullptr && ref_output->hasDeviceMesh(), - "Reference output ", - ref_output, - " has no device mesh."); - - // For fusion inputs, only check if they have a device mesh. We do not - // modify their sharding. For non-fusion inputs, we try to propagate - // shardings from the reference output for parallel types that are not - // already present. + for (auto* target : ir_utils::filterByType(expr->inputs())) { // Allow inputs to be stream parallelized for easier analysis. if (user_sharded_tvs.count(target) && !target->isFusionInput()) { diff --git a/tests/cpp/test_stream.cpp b/tests/cpp/test_stream.cpp index c28398a8eb4..72ed565d993 100644 --- a/tests/cpp/test_stream.cpp +++ b/tests/cpp/test_stream.cpp @@ -15,7 +15,8 @@ #include #include #include -#include +#include +#include #include namespace nvfuser { @@ -86,4 +87,47 @@ TEST_F(StreamTest, haveDifferentShardings) { EXPECT_TRUE(haveDifferentShardings(tv2, tv3, {ParallelType::Stream})); } +TEST_F(StreamTest, ForwardPropagation) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const int64_t s = 2; + + TensorView* in = makeContigTensor(2); + TensorView* w = makeContigTensor(2); + TensorView* out = matmul(in, w); + fusion->addInput(in); + fusion->addInput(w); + fusion->addOutput(out); + + w->outer_split(1, s); + w->axis(1)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass< + preseg_passes::PropagateShardingsPass>::runPass(fusion.get()); + EXPECT_TRUE(out->axis(1)->isStream()) << out; +} + +TEST_F(StreamTest, BackwardPropagation) { + auto fusion = std::make_unique(); + FusionGuard fg(fusion.get()); + + const int64_t s = 2; + + TensorView* tv0 
= makeContigTensor(2); + TensorView* tv1 = add(tv0, IrBuilder::create(1.0)); + TensorView* tv2 = add(tv1, tv1); + fusion->addInput(tv0); + fusion->addOutput(tv2); + + tv2->outer_split(0, s); + tv2->axis(0)->parallelize(ParallelType::Stream); + + preseg_passes::OptimizationPass< + preseg_passes::PropagateShardingsPass>::runPass(fusion.get()); + for (auto* tv : {tv0, tv1, tv2}) { + EXPECT_TRUE(tv->axis(0)->isStream()) << tv; + } +} + } // namespace nvfuser