Safe resize vectorization #3906
Merged
212d522
WIP
naoyam 29c2489
Merge remote-tracking branch 'origin/main' into traverse_all_paths
naoyam 0e829b8
WAR for vectorization of resize
naoyam 82d274d
Revert size changes
naoyam fc9452b
Merge remote-tracking branch 'origin/main' into traverse_all_paths
naoyam 1fffd78
fix
naoyam 6426244
build fix
naoyam dbb47ca
build fix
naoyam bd4ad8b
cleanup
naoyam 01c70e3
WIP
naoyam 0368167
Merge remote-tracking branch 'origin/main' into traverse_all_paths
naoyam a918efc
Merge remote-tracking branch 'origin/main' into traverse_all_paths
naoyam 63e4d68
test fix
naoyam 7d4d82b
fix
naoyam 6e265d3
build fix
naoyam b890e89
Merge remote-tracking branch 'origin/main' into traverse_all_paths
naoyam f4f5b89
cleanup
naoyam 386d01c
remove debug print
naoyam 23f8b54
cleanup
naoyam a6b29c8
skip failing test
naoyam abf346f
cleanup
naoyam d93c12f
Merge remote-tracking branch 'origin/main' into traverse_all_paths
naoyam 31f0ed4
cleanup
naoyam bb92176
fix
naoyam 0d15c9d
Merge remote-tracking branch 'origin/main' into traverse_all_paths
naoyam 357f684
Cache
naoyam 51f6efc
comments
naoyam c4a0b87
cleanup
naoyam 0fbf0f0
Merge remote-tracking branch 'origin/main' into traverse_all_paths
naoyam 220c5b1
Merge remote-tracking branch 'origin/main' into traverse_all_paths
naoyam 3c71a92
revert
naoyam 8967676
rephrase
naoyam b739b41
Merge remote-tracking branch 'origin/main' into traverse_all_paths
naoyam d8e3f05
comment
naoyam 928e555
cleanup
naoyam 999ce45
cleanup
naoyam 546d7ea
cleanup
naoyam 58dcb3a
Merge remote-tracking branch 'origin/main' into traverse_all_paths
naoyam
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,6 +19,7 @@ | |
| #include <iter_visitor.h> | ||
| #include <scheduler/registry.h> | ||
| #include <scheduler/runtime_info.h> | ||
| #include <scheduler/tools/resize_utils.h> | ||
| #include <val_graph_visitor.h> | ||
|
|
||
| #include <c10/util/irange.h> | ||
|
|
@@ -843,6 +844,96 @@ std::vector<std::unordered_map<TensorView*, Val*>> getTvToContigInnerSizeMapsOf( | |
| return mappers; | ||
| } | ||
|
|
||
| // This is a WAR for vectorizing through resized iter domains. The | ||
| // spanning tree based analysis is not guaranteed to take all resize | ||
| // ops into consideration (issue | ||
| // https://github.com/NVIDIA/Fuser/issues/3640). To work around the | ||
| // limitation, grab all factors that must be divisible by the | ||
| // vectorization factor. | ||
| std::unordered_set<Val*> getResizeVectorizationFactors( | ||
| TensorView* reference_tv, | ||
| int64_t break_point) { | ||
| Fusion* fusion = reference_tv->fusion(); | ||
| std::unordered_set<Val*> factors; | ||
| const auto resize_based_ops = scheduler_tools::getResizeBasedOps(fusion); | ||
|
|
||
| if (resize_based_ops.empty()) { | ||
| return factors; | ||
| } | ||
|
|
||
| IdModel id_model(reference_tv->fusion()); | ||
| const auto& graph = id_model.buildExactGraph(); | ||
|
|
||
| const auto ref_groups = graph.toGroups(reference_tv->getLogicalDomain()); | ||
|
|
||
| // For each resize-based tensor op, find all resize ops | ||
| // that exist between the vectorized reference IDs and the output | ||
| // tensor. | ||
| for (auto resize_based_op : resize_based_ops) { | ||
| auto resize_out = resize_based_op->output(0)->as<TensorView>(); | ||
| NVF_ERROR( | ||
| resize_out->hasRoot(), "Unexpected op: ", resize_based_op->toString()); | ||
| // getAllExprGroupsBetween finds exprs between IDs. To make sure | ||
| // the resize op of this resize_based_op tensor op is found, | ||
| // use both the root and logical domains as the traversal targets. | ||
| ValGroups resize_inp_out; | ||
| resize_inp_out.pushBack(graph.toGroups(resize_out->getRootDomain())); | ||
| resize_inp_out.pushBack(graph.toGroups(resize_out->getLogicalDomain())); | ||
|
|
||
| auto expr_path = getAllExprGroupsBetween( | ||
| graph, | ||
| ref_groups, | ||
| resize_inp_out, | ||
| /*require_all_to_visited=*/false) | ||
| .first; | ||
|
|
||
| ValGroups vectorized_groups; | ||
| for (auto it = reference_tv->getLogicalDomain().begin() + break_point; | ||
| it != reference_tv->getLogicalDomain().end(); | ||
| ++it) { | ||
| vectorized_groups.pushBack(graph.toGroup(*it)); | ||
| } | ||
|
|
||
| // Find all resize exprs that appear in expr_path and depend on | ||
| // vectorized_groups. Since expr_path is not guaranteed to be | ||
| // topologically sorted, need to loop through the path until | ||
| // converged. | ||
|
|
||
| bool something_has_changed = true; | ||
|
||
| while (something_has_changed) { | ||
| something_has_changed = false; | ||
| for (const auto& [expr_g, dir] : expr_path) { | ||
| const auto inputs = getInputsOfExprGroup(graph, expr_g, dir); | ||
| if (std::none_of( | ||
| inputs.begin(), inputs.end(), [&](const ValGroup& inp) { | ||
| return vectorized_groups.has(inp); | ||
| })) { | ||
| continue; | ||
| } | ||
|
|
||
| if (vectorized_groups.pushBack( | ||
| getOutputsOfExprGroup(graph, expr_g, dir))) { | ||
| something_has_changed = true; | ||
| } | ||
|
|
||
| auto resize = dynamic_cast<Resize*>(expr_g->front()); | ||
| if (resize == nullptr) { | ||
| continue; | ||
| } | ||
|
|
||
| // These three vals need to be divisible by the vectorization factor | ||
| factors.emplace(resize->leftExpand()); | ||
| factors.emplace(resize->rightExpand()); | ||
| factors.emplace( | ||
| dir == Direction::Forward ? resize->out()->extent() | ||
| : resize->in()->extent()); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| return factors; | ||
| } | ||
|
|
||
| } // namespace | ||
|
|
||
| int64_t getVectorizationFactor( | ||
|
|
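The convergence loop in `getResizeVectorizationFactors` can be illustrated without any nvFuser types. The following is a minimal sketch under stated assumptions: the `Edge` struct stands in for an expression group, plain integer IDs stand in for `ValGroup`s, and `collectReachableResizes` is a hypothetical helper name. None of these are nvFuser APIs. It shows only why the sweep has to repeat when the path is not topologically sorted.

```cpp
#include <set>
#include <vector>

// Hypothetical stand-in for an expression group: integer IDs replace
// ValGroups, and is_resize marks edges that model a Resize op.
struct Edge {
  std::vector<int> inputs;
  std::vector<int> outputs;
  bool is_resize;
};

// Returns the indices of resize edges that depend on any initially
// reachable ID. `path` need not be topologically sorted, so the sweep
// repeats until the reachable set stops growing.
std::set<int> collectReachableResizes(
    const std::vector<Edge>& path,
    std::set<int> reachable) {
  std::set<int> resizes;
  bool changed = true;
  while (changed) {
    changed = false;
    for (int i = 0; i < static_cast<int>(path.size()); ++i) {
      const Edge& e = path[i];
      // Skip edges whose inputs are all outside the reachable set.
      bool depends = false;
      for (int in : e.inputs) {
        if (reachable.count(in) > 0) {
          depends = true;
          break;
        }
      }
      if (!depends) {
        continue;
      }
      // Propagate reachability to the outputs; a new ID forces another sweep.
      for (int out : e.outputs) {
        if (reachable.insert(out).second) {
          changed = true;
        }
      }
      if (e.is_resize) {
        resizes.insert(i);
      }
    }
  }
  return resizes;
}
```

In this sketch, a resize edge listed before the edge that produces its input is missed on the first sweep and picked up on the second, which is the situation the `something_has_changed` loop guards against.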
@@ -881,6 +972,15 @@ int64_t getVectorizationFactor( | |
| return 1; | ||
| } | ||
|
|
||
| auto resize_factors_entry = | ||
| HeuristicDataCacheEntry<HeuristicCompileTime::ResizeVectorizationFactors>( | ||
|
Collaborator: 👏
||
| data_cache, [&reference_tv, &break_point]() { | ||
| return std::make_unique<std::unordered_set<Val*>>( | ||
| getResizeVectorizationFactors(reference_tv, break_point)); | ||
| }); | ||
|
|
||
| const auto& resize_factors = resize_factors_entry.get(); | ||
|
|
||
| int64_t max_vec_size = SchedulerRuntimeInfo::max_alignment_size_in_byte; | ||
| const auto& tv_to_inner_size_map = vectorize_maps_entry.get().at(break_point); | ||
|
|
||
|
|
@@ -920,6 +1020,19 @@ int64_t getVectorizationFactor( | |
| max_vec_size); | ||
| } | ||
|
|
||
| // This is a WAR for vectorization through resize as the spanning | ||
| // tree based traversal is not guaranteed to reflect all resize ops | ||
| // that may affect vectorization. This is a safe but conservative | ||
| // analysis since it should only be necessary for innermost IDs. | ||
| for (const auto resize_factor : resize_factors) { | ||
| auto inferred_val = | ||
| runtime_info.expressionEvaluator().evaluate(resize_factor); | ||
| if (!inferred_val.hasValue()) { | ||
| return 1; | ||
| } | ||
| max_vec_size = std::gcd(max_vec_size, inferred_val.as<int64_t>()); | ||
|
||
| } | ||
|
|
||
| return max_vec_size; | ||
| } | ||
|
|
||
|
|
||
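The final clamp in `getVectorizationFactor` reduces the candidate width to a common divisor of every collected resize value. A minimal standalone sketch, assuming a hypothetical helper `clampVectorWidth` and using `std::optional` to model a value that the expression evaluator could not infer (neither is part of nvFuser):

```cpp
#include <cstdint>
#include <numeric>
#include <optional>
#include <vector>

// Hypothetical helper, not an nvFuser function. Clamps a candidate
// vectorization width so that it divides every resize-related value.
// An empty optional models a value that could not be evaluated, in
// which case the sketch conservatively falls back to a width of 1.
int64_t clampVectorWidth(
    int64_t max_vec_size,
    const std::vector<std::optional<int64_t>>& factors) {
  for (const auto& factor : factors) {
    if (!factor.has_value()) {
      return 1;
    }
    // std::gcd(x, 0) == x, so a zero expansion leaves the width unchanged.
    max_vec_size = std::gcd(max_vec_size, *factor);
  }
  return max_vec_size;
}
```

For example, a candidate width of 8 with a resized extent of 6 is reduced to 2, while a zero left or right expansion imposes no constraint.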