diff --git a/csrc/scheduler/vectorize_helper.cpp b/csrc/scheduler/vectorize_helper.cpp index 6ec81585869..091f3519a0d 100644 --- a/csrc/scheduler/vectorize_helper.cpp +++ b/csrc/scheduler/vectorize_helper.cpp @@ -842,6 +842,59 @@ std::vector> getTvToContigInnerSizeMapsOf( return mappers; } +// Check if a traversal from vectorized reference IDs may reach the +// IDs of a resize expr without visiting the Resize expr itself. That's +// problematic for the vectorization analysis as the spanning-tree +// based analysis may miss the constraint by the Resize expr. +// +// For this analysis, we start a traversal from the vectorized +// reference IDs to both the input and output of the Resize expr but +// disallow visiting the Resize expr itself. If the traversal is still +// successful, it means there's a path from the reference IDs to the +// resize input and output IDs without visiting the Resize expr. +// +// Permissive BFS is used in this traversal as the vectorized +// reference IDs may not have all the dependencies for the +// traversal. For example, suppose there's a split reshape, and only +// the innermost ID is vectorized. The standard BFS is not able to +// move forward if only the vectorized ID is given as the backward +// split requires both outputs to be present. 
+class CanSkipResize : public ValGraphPermissiveBFS { + public: + static bool run( + const ValGraph& graph, + const ValGroups& ref_groups, + Resize* resize) { + ValGroups resize_in_out_groups; + resize_in_out_groups.pushBack(graph.toGroup(resize->in())); + resize_in_out_groups.pushBack(graph.toGroup(resize->out())); + CanSkipResize bfs(graph, ref_groups, resize_in_out_groups, resize); + bfs.traverse(); + return bfs.allToNodesVisited(); + } + + CanSkipResize( + const ValGraph& graph, + const ValGroups& ref_groups, + const ValGroups& resize_in_out_groups, + Resize* resize) + : ValGraphPermissiveBFS( + graph, + {ref_groups.begin(), ref_groups.end()}, + {resize_in_out_groups.begin(), resize_in_out_groups.end()}, + /*require_all_to_visited=*/false, + /*allowed_direction=*/Direction::Undefined), + resize_(resize) {} + + bool excludeFromTraversal(const NodeType& node) const override { + const ExprGroup* e = std::get_if(&node); + return e != nullptr && (*e)->has(resize_); + } + + private: + Resize* resize_ = nullptr; +}; + // This is a WAR for vectorizing through resized iter domains. The // spanning tree based analysis is not guaranteed to take all resize // ops into considerations (issue @@ -852,84 +905,48 @@ std::unordered_set getResizeVectorizationFactors( TensorView* reference_tv, int64_t break_point) { Fusion* fusion = reference_tv->fusion(); - std::unordered_set factors; const auto resize_based_ops = scheduler_tools::getResizeBasedOps(fusion); if (resize_based_ops.empty()) { - return factors; + return {}; } - IdModel id_model(reference_tv->fusion()); + IdModel id_model(fusion); const auto& graph = id_model.buildExactGraph(); - const auto ref_groups = graph.toGroups(reference_tv->getLogicalDomain()); + std::unordered_set resize_factors; - // For each of resize-based tensor ops, find all resize ops - // that exist between the vectorized reference IDs and the output - // tensor. 
- for (auto resize_based_op : resize_based_ops) { - auto resize_out = resize_based_op->output(0)->as(); - NVF_ERROR( - resize_out->hasRoot(), "Unexpected op: ", resize_based_op->toString()); - // getAllExprGroupsBetween finds exprs between IDs. To make sure - // the the resize op of this resize_based_op tensor op is found, - // use both the root and logical domains as the traversal targets. - ValGroups resize_inp_out; - resize_inp_out.pushBack(graph.toGroups(resize_out->getRootDomain())); - resize_inp_out.pushBack(graph.toGroups(resize_out->getLogicalDomain())); - - auto expr_path = getAllExprGroupsBetween( - graph, - ref_groups, - resize_inp_out, - /*require_all_to_visited=*/false) - .first; - - ValGroups vectorized_groups; - for (auto it = reference_tv->getLogicalDomain().begin() + break_point; - it != reference_tv->getLogicalDomain().end(); - ++it) { - vectorized_groups.pushBack(graph.toGroup(*it)); + auto add_resize_factors = [&](Resize* resize) { + if (!resize->leftExpand()->isZeroInt()) { + resize_factors.insert(resize->leftExpand()); } + if (!resize->rightExpand()->isZeroInt()) { + resize_factors.insert(resize->rightExpand()); + } + }; - // Find all resize exprs that appear in expr_path and depend on - // vectorized_groups. Since expr_path is not guaranteed to be - // topologically sorted, need to loop through the path until - // converged. 
- - bool something_has_changed = true; - while (something_has_changed) { - something_has_changed = false; - for (const auto& [expr_g, dir] : expr_path) { - const auto inputs = getInputsOfExprGroup(graph, expr_g, dir); - if (std::none_of( - inputs.begin(), inputs.end(), [&](const ValGroup& inp) { - return vectorized_groups.has(inp); - })) { - continue; - } - - if (vectorized_groups.pushBack( - getOutputsOfExprGroup(graph, expr_g, dir))) { - something_has_changed = true; - } - - auto resize = dynamic_cast(expr_g->front()); - if (resize == nullptr) { - continue; - } + const ValGroups ref_vec_groups = graph.toGroups(std::vector{ + reference_tv->getLogicalDomain().begin() + break_point, + reference_tv->getLogicalDomain().end()}); + + // For each of the Resize exprs, if it's reachable from the reference + // vectorized IDs without visiting the Resize expr itself, its + // constraint may not be reflected in the inner sizes. + for (auto resize : resize_based_ops) { + auto resize_out_tv = resize->output(0)->as(); + for (const auto logical_id : resize_out_tv->getLogicalDomain()) { + auto resize = dynamic_cast(logical_id->definition()); + if (resize == nullptr) { + continue; + } - // These three vals need to be divisible - factors.emplace(resize->leftExpand()); - factors.emplace(resize->rightExpand()); - factors.emplace( - dir == Direction::Forward ? 
resize->out()->extent() - : resize->in()->extent()); + if (CanSkipResize::run(graph, ref_vec_groups, resize)) { + add_resize_factors(resize); } } } - return factors; + return resize_factors; } } // namespace @@ -1028,7 +1045,11 @@ int64_t getVectorizationFactor( if (!inferred_val.hasValue()) { return 1; } - max_vec_size = std::gcd(max_vec_size, inferred_val.as()); + auto inferred_val_int = inferred_val.as(); + if (inferred_val_int == 0) { + continue; + } + max_vec_size = std::gcd(max_vec_size, inferred_val_int); } return max_vec_size; diff --git a/tests/cpp/test_resize.cpp b/tests/cpp/test_resize.cpp index 12ca31a929f..4aaac477a52 100644 --- a/tests/cpp/test_resize.cpp +++ b/tests/cpp/test_resize.cpp @@ -5970,7 +5970,7 @@ TEST_F(ResizeTest, AvoidCachingSliceInput) { } } -TEST_F(ResizeTest, VectorizeSliceMultiplePaths) { +TEST_F(ResizeTest, VectorizeInnerSliceMultiplePaths) { auto fusion_ptr = std::make_unique(); auto& fusion = *fusion_ptr; FusionGuard fg(fusion_ptr.get()); @@ -6005,6 +6005,50 @@ TEST_F(ResizeTest, VectorizeSliceMultiplePaths) { EXPECT_EQ(tv6->getLoopDomain().back()->extent()->evaluate(), 2); } +// The current analysis is not precise enough to pass this test +TEST_F(ResizeTest, DISABLED_VectorizeOuterSliceMultiplePaths) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + const std::vector shape{4, 1024 * 1024}; + + auto tv0 = makeContigConcreteTensor(shape); + fusion.addInput(tv0); + + auto tv1 = + pad(tv0, + {fusion.zeroVal(), + fusion.zeroVal(), + IrBuilder::create(2), + IrBuilder::create(2)}); + auto tv2 = + pad(tv0, + {fusion.zeroVal(), + fusion.zeroVal(), + fusion.zeroVal(), + IrBuilder::create(4)}); + auto tv3 = add(tv1, tv2); + fusion.addOutput(tv3); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + auto t0 = at::randn(shape, options); + + auto outputs = scheduleAndRun(&fusion, SchedulerType::PointWise, {t0}); + testValidate(&fusion, outputs.outputs, 
{t0}, __LINE__, __FILE__); + + // While there's a pad with factor of 2, it shouldn't matter as the + // inner ID is large enough. + auto out_tv = tv3; + auto vec_id_it = + std::ranges::find_if(out_tv->getLoopDomain(), [](IterDomain* loop_id) { + return loop_id->getParallelType() == ParallelType::Vectorize; + }); + ASSERT_NE(vec_id_it, out_tv->getLoopDomain().end()) + << "Vectorized ID not found: " << out_tv->toString(); + EXPECT_EQ((*vec_id_it)->extent()->evaluate(), 4); +} + // Repro of issue #4202 TEST_F(ResizeTest, PropagateResizeThroughMultiplePaths) { auto fusion_ptr = std::make_unique(); @@ -6040,4 +6084,49 @@ TEST_F(ResizeTest, PropagateResizeThroughMultiplePaths) { testValidate(&fusion, outputs.outputs, {t0, t1}, __LINE__, __FILE__); } +// Check if vectorization is properly applied even when a resized ID +// is reachable from vectorized IDs. Pattern extracted from Litgpt +// LLama RoPE backward. +TEST_F(ResizeTest, VectorizeOuterPad) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + const std::vector shape1{1, 8, 4, 8192, 128}; + const std::vector shape2{1, 8, 1, 8192, 128}; + auto tv0 = makeContigConcreteTensor(shape1, DataType::BFloat16); + fusion.addInput(tv0); + auto tv1 = makeContigConcreteTensor(shape2, DataType::BFloat16); + fusion.addInput(tv1); + auto tv2 = makeContigConcreteTensor(shape2, DataType::BFloat16); + fusion.addInput(tv2); + + // [1, 8, 6, 8192, 128] + auto tv3 = cat({tv0, tv1, tv2}, 2); + // [1, 8192, 8, 6, 128] + auto tv4 = permute(tv3, {0, 3, 1, 2, 4}); + auto tv5 = reshape(tv4, {1, 8192, 8, 6, 128}, {1, 8192, 6144}); + fusion.addOutput(tv5); + + auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0); + auto t0 = at::randn(shape1, options); + auto t1 = at::randn(shape2, options); + auto t2 = at::randn(shape2, options); + + auto outputs = + scheduleAndRun(&fusion, SchedulerType::PointWise, {t0, t1, t2}); + testValidate(&fusion, outputs.outputs, {t0, t1, 
t2}, __LINE__, __FILE__); + + auto out_tv = tv5; + // While there's a pad with factor of 2, it shouldn't matter as the + // inner ID is large enough. + auto vec_id_it = + std::ranges::find_if(out_tv->getLoopDomain(), [](IterDomain* loop_id) { + return loop_id->getParallelType() == ParallelType::Vectorize; + }); + ASSERT_NE(vec_id_it, out_tv->getLoopDomain().end()) + << "Vectorized ID not found: " << out_tv->toString(); + EXPECT_EQ((*vec_id_it)->extent()->evaluate(), 8); +} + } // namespace nvfuser