-
Notifications
You must be signed in to change notification settings - Fork 79
More precise WAR for resize vectorization #4305
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -842,6 +842,59 @@ std::vector<std::unordered_map<TensorView*, Val*>> getTvToContigInnerSizeMapsOf( | |
| return mappers; | ||
| } | ||
|
|
||
| // Check if a traversal from vectorized reference IDs may reach the | ||
| // IDs of a resize expr without visiting the Resize expr itself. That's | ||
| // problematic for the vectorization analysis as the spanning-tree | ||
| // based analysis may miss the constraint by the Resize expr. | ||
| // | ||
| // For this analysis, we start a traversal from the vectorized | ||
| // reference IDs to both the input and output of the Resize expr but | ||
| // disallow visiting the Resize expr itself. If the traversal is still | ||
| // successful, it means there's a path from the reference IDs to the | ||
| // resize input and output IDs without visiting the Resize expr. | ||
| // | ||
| // Permissive BFS is used in this traversal as the vectorized | ||
| // reference IDs may not have all the dependencies for the | ||
| traversal. For example, suppose there's a split reshape, and only | ||
| the innermost ID is vectorized. The standard BFS is not able to | ||
| move forward if only the vectorized ID is given as the backward | ||
| split requires both outputs to be present. | ||
| class CanSkipResize : public ValGraphPermissiveBFS { | ||
| public: | ||
| static bool run( | ||
| const ValGraph& graph, | ||
| const ValGroups& ref_groups, | ||
| Resize* resize) { | ||
| ValGroups resize_in_out_groups; | ||
| resize_in_out_groups.pushBack(graph.toGroup(resize->in())); | ||
| resize_in_out_groups.pushBack(graph.toGroup(resize->out())); | ||
| CanSkipResize bfs(graph, ref_groups, resize_in_out_groups, resize); | ||
| bfs.traverse(); | ||
| return bfs.allToNodesVisited(); | ||
| } | ||
|
|
||
| CanSkipResize( | ||
| const ValGraph& graph, | ||
| const ValGroups& ref_groups, | ||
| const ValGroups& resize_in_out_groups, | ||
| Resize* resize) | ||
| : ValGraphPermissiveBFS( | ||
| graph, | ||
| {ref_groups.begin(), ref_groups.end()}, | ||
| {resize_in_out_groups.begin(), resize_in_out_groups.end()}, | ||
| /*require_all_to_visited=*/false, | ||
| /*allowed_direction=*/Direction::Undefined), | ||
| resize_(resize) {} | ||
|
|
||
| bool excludeFromTraversal(const NodeType& node) const override { | ||
| const ExprGroup* e = std::get_if<ExprGroup>(&node); | ||
| return e != nullptr && (*e)->has(resize_); | ||
| } | ||
|
|
||
| private: | ||
| Resize* resize_ = nullptr; | ||
| }; | ||
|
|
||
| // This is a WAR for vectorizing through resized iter domains. The | ||
| // spanning tree based analysis is not guaranteed to take all resize | ||
| // ops into considerations (issue | ||
|
|
@@ -852,84 +905,48 @@ std::unordered_set<Val*> getResizeVectorizationFactors( | |
| TensorView* reference_tv, | ||
| int64_t break_point) { | ||
| Fusion* fusion = reference_tv->fusion(); | ||
| std::unordered_set<Val*> factors; | ||
| const auto resize_based_ops = scheduler_tools::getResizeBasedOps(fusion); | ||
|
|
||
| if (resize_based_ops.empty()) { | ||
| return factors; | ||
| return {}; | ||
| } | ||
|
|
||
| IdModel id_model(reference_tv->fusion()); | ||
| IdModel id_model(fusion); | ||
| const auto& graph = id_model.buildExactGraph(); | ||
|
|
||
| const auto ref_groups = graph.toGroups(reference_tv->getLogicalDomain()); | ||
| std::unordered_set<Val*> resize_factors; | ||
|
|
||
| // For each of resize-based tensor ops, find all resize ops | ||
| // that exist between the vectorized reference IDs and the output | ||
| // tensor. | ||
| for (auto resize_based_op : resize_based_ops) { | ||
| auto resize_out = resize_based_op->output(0)->as<TensorView>(); | ||
| NVF_ERROR( | ||
| resize_out->hasRoot(), "Unexpected op: ", resize_based_op->toString()); | ||
| // getAllExprGroupsBetween finds exprs between IDs. To make sure | ||
| that the resize op of this resize_based_op tensor op is found, | ||
| // use both the root and logical domains as the traversal targets. | ||
| ValGroups resize_inp_out; | ||
| resize_inp_out.pushBack(graph.toGroups(resize_out->getRootDomain())); | ||
| resize_inp_out.pushBack(graph.toGroups(resize_out->getLogicalDomain())); | ||
|
|
||
| auto expr_path = getAllExprGroupsBetween( | ||
| graph, | ||
| ref_groups, | ||
| resize_inp_out, | ||
| /*require_all_to_visited=*/false) | ||
| .first; | ||
|
|
||
| ValGroups vectorized_groups; | ||
| for (auto it = reference_tv->getLogicalDomain().begin() + break_point; | ||
| it != reference_tv->getLogicalDomain().end(); | ||
| ++it) { | ||
| vectorized_groups.pushBack(graph.toGroup(*it)); | ||
| auto add_resize_factors = [&](Resize* resize) { | ||
| if (!resize->leftExpand()->isZeroInt()) { | ||
| resize_factors.insert(resize->leftExpand()); | ||
| } | ||
| if (!resize->rightExpand()->isZeroInt()) { | ||
| resize_factors.insert(resize->rightExpand()); | ||
| } | ||
| }; | ||
|
|
||
| // Find all resize exprs that appear in expr_path and depend on | ||
| // vectorized_groups. Since expr_path is not guaranteed to be | ||
| // topologically sorted, need to loop through the path until | ||
| // converged. | ||
|
|
||
| bool something_has_changed = true; | ||
| while (something_has_changed) { | ||
| something_has_changed = false; | ||
| for (const auto& [expr_g, dir] : expr_path) { | ||
| const auto inputs = getInputsOfExprGroup(graph, expr_g, dir); | ||
| if (std::none_of( | ||
| inputs.begin(), inputs.end(), [&](const ValGroup& inp) { | ||
| return vectorized_groups.has(inp); | ||
| })) { | ||
| continue; | ||
| } | ||
|
|
||
| if (vectorized_groups.pushBack( | ||
| getOutputsOfExprGroup(graph, expr_g, dir))) { | ||
| something_has_changed = true; | ||
| } | ||
|
|
||
| auto resize = dynamic_cast<Resize*>(expr_g->front()); | ||
| if (resize == nullptr) { | ||
| continue; | ||
| } | ||
| const ValGroups ref_vec_groups = graph.toGroups(std::vector<Val*>{ | ||
| reference_tv->getLogicalDomain().begin() + break_point, | ||
| reference_tv->getLogicalDomain().end()}); | ||
|
|
||
| // For each of Resize exprs, if it's reachable from the reference | ||
| // vectorized IDs without visiting the Resize expr itself, its | ||
| // constraint may not be reflected in the inner sizes. | ||
| for (auto resize : resize_based_ops) { | ||
| auto resize_out_tv = resize->output(0)->as<TensorView>(); | ||
| for (const auto logical_id : resize_out_tv->getLogicalDomain()) { | ||
| auto resize = dynamic_cast<Resize*>(logical_id->definition()); | ||
| if (resize == nullptr) { | ||
| continue; | ||
| } | ||
|
|
||
| // These three vals need to be divisible | ||
| factors.emplace(resize->leftExpand()); | ||
| factors.emplace(resize->rightExpand()); | ||
| factors.emplace( | ||
| dir == Direction::Forward ? resize->out()->extent() | ||
| : resize->in()->extent()); | ||
| if (CanSkipResize::run(graph, ref_vec_groups, resize)) { | ||
| add_resize_factors(resize); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| return factors; | ||
| return resize_factors; | ||
| } | ||
|
|
||
| } // namespace | ||
|
|
@@ -1028,7 +1045,11 @@ int64_t getVectorizationFactor( | |
| if (!inferred_val.hasValue()) { | ||
| return 1; | ||
| } | ||
| max_vec_size = std::gcd(max_vec_size, inferred_val.as<int64_t>()); | ||
| auto inferred_val_int = inferred_val.as<int64_t>(); | ||
| if (inferred_val_int == 0) { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is for dynamic resize extents that would be 0? |
||
| continue; | ||
| } | ||
| max_vec_size = std::gcd(max_vec_size, inferred_val_int); | ||
| } | ||
|
|
||
| return max_vec_size; | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -5970,7 +5970,7 @@ TEST_F(ResizeTest, AvoidCachingSliceInput) { | |
| } | ||
| } | ||
|
|
||
| TEST_F(ResizeTest, VectorizeSliceMultiplePaths) { | ||
| TEST_F(ResizeTest, VectorizeInnerSliceMultiplePaths) { | ||
| auto fusion_ptr = std::make_unique<Fusion>(); | ||
| auto& fusion = *fusion_ptr; | ||
| FusionGuard fg(fusion_ptr.get()); | ||
|
|
@@ -6005,6 +6005,50 @@ TEST_F(ResizeTest, VectorizeSliceMultiplePaths) { | |
| EXPECT_EQ(tv6->getLoopDomain().back()->extent()->evaluate(), 2); | ||
| } | ||
|
|
||
| // The current analysis is not precise enough to pass this test | ||
|
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In this test, tv0 is resized in two different ways. The spanning-tree based analysis is not guaranteed to correctly identify the vectorization constraint. The WAR when applied to this case is still too conservative. |
||
| TEST_F(ResizeTest, DISABLED_VectorizeOuterSliceMultiplePaths) { | ||
| auto fusion_ptr = std::make_unique<Fusion>(); | ||
| auto& fusion = *fusion_ptr; | ||
| FusionGuard fg(fusion_ptr.get()); | ||
|
|
||
| const std::vector<int64_t> shape{4, 1024 * 1024}; | ||
|
|
||
| auto tv0 = makeContigConcreteTensor(shape); | ||
| fusion.addInput(tv0); | ||
|
|
||
| auto tv1 = | ||
| pad(tv0, | ||
| {fusion.zeroVal(), | ||
| fusion.zeroVal(), | ||
| IrBuilder::create<Val>(2), | ||
| IrBuilder::create<Val>(2)}); | ||
| auto tv2 = | ||
| pad(tv0, | ||
| {fusion.zeroVal(), | ||
| fusion.zeroVal(), | ||
| fusion.zeroVal(), | ||
| IrBuilder::create<Val>(4)}); | ||
| auto tv3 = add(tv1, tv2); | ||
| fusion.addOutput(tv3); | ||
|
|
||
| auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); | ||
| auto t0 = at::randn(shape, options); | ||
|
|
||
| auto outputs = scheduleAndRun(&fusion, SchedulerType::PointWise, {t0}); | ||
| testValidate(&fusion, outputs.outputs, {t0}, __LINE__, __FILE__); | ||
|
|
||
| // While there's a pad with factor of 2, it shouldn't matter as the | ||
| // inner ID is large enough. | ||
| auto out_tv = tv3; | ||
| auto vec_id_it = | ||
| std::ranges::find_if(out_tv->getLoopDomain(), [](IterDomain* loop_id) { | ||
| return loop_id->getParallelType() == ParallelType::Vectorize; | ||
| }); | ||
| ASSERT_NE(vec_id_it, out_tv->getLoopDomain().end()) | ||
| << "Vectorized ID not found: " << out_tv->toString(); | ||
| EXPECT_EQ((*vec_id_it)->extent()->evaluate(), 4); | ||
| } | ||
|
|
||
| // Repro of issue #4202 | ||
| TEST_F(ResizeTest, PropagateResizeThroughMultiplePaths) { | ||
| auto fusion_ptr = std::make_unique<Fusion>(); | ||
|
|
@@ -6040,4 +6084,49 @@ TEST_F(ResizeTest, PropagateResizeThroughMultiplePaths) { | |
| testValidate(&fusion, outputs.outputs, {t0, t1}, __LINE__, __FILE__); | ||
| } | ||
|
|
||
| // Check if vectorization is properly applied even when a resized ID | ||
| // is reachable from vectorized IDs. Pattern extracted from Litgpt | ||
| // LLama RoPE backward. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nitpick on comment. This is the case where it's safe to skip the additional resize check. So that means the resized ID is NOT reachable from vectorized IDs. |
||
| TEST_F(ResizeTest, VectorizeOuterPad) { | ||
| auto fusion_ptr = std::make_unique<Fusion>(); | ||
| auto& fusion = *fusion_ptr; | ||
| FusionGuard fg(fusion_ptr.get()); | ||
|
|
||
| const std::vector<int64_t> shape1{1, 8, 4, 8192, 128}; | ||
| const std::vector<int64_t> shape2{1, 8, 1, 8192, 128}; | ||
| auto tv0 = makeContigConcreteTensor(shape1, DataType::BFloat16); | ||
| fusion.addInput(tv0); | ||
| auto tv1 = makeContigConcreteTensor(shape2, DataType::BFloat16); | ||
| fusion.addInput(tv1); | ||
| auto tv2 = makeContigConcreteTensor(shape2, DataType::BFloat16); | ||
| fusion.addInput(tv2); | ||
|
|
||
| // [1, 8, 6, 8192, 128] | ||
| auto tv3 = cat({tv0, tv1, tv2}, 2); | ||
| // [1, 8192, 8, 6, 128] | ||
| auto tv4 = permute(tv3, {0, 3, 1, 2, 4}); | ||
| auto tv5 = reshape(tv4, {1, 8192, 8, 6, 128}, {1, 8192, 6144}); | ||
| fusion.addOutput(tv5); | ||
|
|
||
| auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0); | ||
| auto t0 = at::randn(shape1, options); | ||
| auto t1 = at::randn(shape2, options); | ||
| auto t2 = at::randn(shape2, options); | ||
|
|
||
| auto outputs = | ||
| scheduleAndRun(&fusion, SchedulerType::PointWise, {t0, t1, t2}); | ||
| testValidate(&fusion, outputs.outputs, {t0, t1, t2}, __LINE__, __FILE__); | ||
|
|
||
| auto out_tv = tv5; | ||
| // While there's a pad with factor of 2, it shouldn't matter as the | ||
| // inner ID is large enough. | ||
| auto vec_id_it = | ||
| std::ranges::find_if(out_tv->getLoopDomain(), [](IterDomain* loop_id) { | ||
| return loop_id->getParallelType() == ParallelType::Vectorize; | ||
| }); | ||
| ASSERT_NE(vec_id_it, out_tv->getLoopDomain().end()) | ||
| << "Vectorized ID not found: " << out_tv->toString(); | ||
| EXPECT_EQ((*vec_id_it)->extent()->evaluate(), 8); | ||
| } | ||
|
|
||
| } // namespace nvfuser | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
qq: here we are calling
`allToNodesVisited()`? but the init function below has `/*require_all_to_visited=*/false`, so we are returning true here as long as a single node is visited in the target, which I think is the right behavior. But the function name is somewhat confusing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, the traversal should continue until no further progress is made. The
The `require_all_to_visited` flag means it's considered an error if not all of the `to` nodes were able to be reached. Here, we just want to check whether all of the
`to` nodes are reachable. It isn't an error even if not. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sorry I got confused myself. This returns true indicating it's safe to skip the check. So
`allToNodesVisited` is the proper name for the function. Thanks for elaborating on this one.