-
Notifications
You must be signed in to change notification settings - Fork 79
Add a visitor for ValGraph #1713
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
bbd7786
Add a visitor for ValGraph
naoyam 42b1b23
comments
naoyam c6a6928
Remove a stale comment
naoyam a8a3bbf
Avoid using popFront
naoyam e43bd3c
Merge branch 'main' into valgraph_visitor
naoyam ebd14d2
Update the reference order (still valid)
naoyam bef4a78
Remove popFront as it's no longer used
naoyam 2b02316
Rename id to val
naoyam 8accbd9
Remove accidentally added files
naoyam e33a005
Refactor traversal loop
naoyam b92ec79
cleanup
naoyam 0059992
Update csrc/val_graph_visitor.cpp
naoyam cfc784e
Update csrc/val_graph_visitor.cpp
naoyam b891e36
handle trivial expr
naoyam d1ca512
Update csrc/val_graph_visitor.cpp
naoyam 117513a
clang-format
naoyam a66bdb9
typo
naoyam 85948b6
Merge branch 'main' into valgraph_visitor
naoyam 7d133cc
Merge branch 'main' into valgraph_visitor
naoyam 465146d
cleanup
naoyam 208edf3
Merge branch 'main' into valgraph_visitor
naoyam File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,152 @@ | ||
| // clang-format off | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. | ||
| * All rights reserved. | ||
| * SPDX-License-Identifier: BSD-3-Clause | ||
| */ | ||
| // clang-format on | ||
| #include <val_graph_visitor.h> | ||
|
|
||
| #include <id_model/to_string.h> | ||
|
|
||
| namespace nvfuser { | ||
|
|
||
| void ValGraphVisitor::traverse() { | ||
| const ValGroups terminating_inputs = graph().getTerminatingInputs(); | ||
| std::deque<ValGroup> to_visit_vals( | ||
| terminating_inputs.begin(), terminating_inputs.end()); | ||
| ValGroups visited_vals; | ||
|
|
||
| std::deque<ExprGroup> to_visit_exprs; | ||
| ExprGroups visited_exprs; | ||
|
|
||
| auto is_expr_ready = [&](const ExprGroup& expr_group) -> bool { | ||
| const auto inp_groups = graph().inputGroups(expr_group); | ||
| return std::all_of( | ||
| inp_groups.begin(), inp_groups.end(), [&](ValGroup val_group) { | ||
| return visited_vals.has(val_group) || val_group->empty(); | ||
| }); | ||
| }; | ||
|
|
||
| // If any input of the def expr is mapped with the val | ||
| // group itself, i.e., a trivial expr, allow visiting the | ||
| // val group first. The trivial expr group will be visited | ||
| // after the val group. | ||
| // | ||
| // Example: | ||
| // | ||
| // [i0, 1] | ||
| // merge | ||
| // [i0*1] | ||
| // map i0 and i0*1 | ||
| // ValGroups: {{i0, i0*1}, {1}} | ||
| // | ||
| // Then, {i0, i0*1} and {1} would be visited first, then the merge | ||
| // expr group would be visited. {i0, i0*1} is also an output group | ||
| // of the merge but since it's already in the visited set, it would | ||
| // not be visited again. | ||
| // | ||
| // See also IdModelTest.ValGraphStmtSort3 for a concrete example. | ||
| auto is_val_ready = [&](const ValGroup& val_group) -> bool { | ||
| if (terminating_inputs.has(val_group)) { | ||
| return true; | ||
| } | ||
| const ExprGroups& unique_defs = graph().getDefinitions(val_group); | ||
| return std::all_of( | ||
| unique_defs.begin(), unique_defs.end(), [&](ExprGroup expr_group) { | ||
| if (expr_group->empty() || visited_exprs.has(expr_group)) { | ||
| return true; | ||
| } | ||
| // Handle ExprGroups that return one or some of its input ValGroups as | ||
| // output. This expr_group is not visited yet, which means there're | ||
| // input ValGroups that are not yet visited. If those not-visited | ||
| // inputs are actually the same as val_group, visit val_group at this | ||
| // point to resolve the circular dependency. | ||
| for (const ValGroup& input_group : graph().inputGroups(expr_group)) { | ||
| if (input_group != val_group && !visited_vals.has(input_group) && | ||
| input_group->empty()) { | ||
| return false; | ||
| } | ||
| } | ||
| return true; | ||
| }); | ||
| }; | ||
|
|
||
| // Detect if nothing has been processed which would put us in an infinite | ||
| // loop | ||
| bool something_was_processed = false; | ||
|
|
||
| do { | ||
| something_was_processed = false; | ||
|
|
||
| // Process expressions first as all definitions of vals have to be | ||
| // processed before we can process that val. | ||
|
|
||
| while (!to_visit_exprs.empty()) { | ||
| ExprGroup current_expr_group = to_visit_exprs.front(); | ||
| to_visit_exprs.pop_front(); | ||
| NVF_ERROR(!current_expr_group->empty()); | ||
| if (visited_exprs.has(current_expr_group)) { | ||
| continue; | ||
| } | ||
|
|
||
| if (is_expr_ready(current_expr_group)) { | ||
| handle(current_expr_group); | ||
|
|
||
| something_was_processed = true; | ||
| visited_exprs.pushBack(current_expr_group); | ||
|
|
||
| for (const ValGroup& output_group : | ||
| graph().outputGroups(current_expr_group)) { | ||
| to_visit_vals.push_back(output_group); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| std::deque<ValGroup> still_to_visit_vals; | ||
| while (!to_visit_vals.empty()) { | ||
| auto current_val_group = to_visit_vals.front(); | ||
| to_visit_vals.pop_front(); | ||
| NVF_ERROR(!current_val_group->empty()); | ||
| if (visited_vals.has(current_val_group)) { | ||
| continue; | ||
| } | ||
|
|
||
| if (is_val_ready(current_val_group)) { | ||
| handle(current_val_group); | ||
|
|
||
| something_was_processed = true; | ||
| visited_vals.pushBack(current_val_group); | ||
|
|
||
| for (const ExprGroup& use_group : graph().getUses(current_val_group)) { | ||
| to_visit_exprs.push_back(use_group); | ||
| } | ||
| } else { | ||
| still_to_visit_vals.push_back(current_val_group); | ||
| } | ||
| } | ||
|
|
||
| std::swap(to_visit_vals, still_to_visit_vals); | ||
|
|
||
| } while (something_was_processed); | ||
|
|
||
| if (!to_visit_vals.empty()) { | ||
| std::stringstream ss; | ||
| ss << "The graph has an infinite loop. The following Vals should be visited but are never ready:"; | ||
| for (const ValGroup& vg : to_visit_vals) { | ||
| ss << " " << nvfuser::toString(vg); | ||
| } | ||
| NVF_ERROR(false, ss.str()); | ||
| } | ||
|
|
||
| if (!to_visit_exprs.empty()) { | ||
| std::stringstream ss; | ||
| ss << "The graph has an infinite loop. The following Exprs should be visited but are never ready:"; | ||
| for (const ExprGroup& eg : to_visit_exprs) { | ||
| ss << " " << nvfuser::toString(eg); | ||
| } | ||
| NVF_ERROR(false, ss.str()); | ||
| } | ||
| } | ||
|
|
||
| } // namespace nvfuser | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,121 @@ | ||
| // clang-format off | ||
| /* | ||
| * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. | ||
| * All rights reserved. | ||
| * SPDX-License-Identifier: BSD-3-Clause | ||
| */ | ||
| // clang-format on | ||
| #pragma once | ||
|
|
||
| #include <disjoint_set.h> | ||
| #include <ir/all_nodes.h> | ||
| #include <val_graph.h> | ||
|
|
||
| namespace nvfuser { | ||
|
|
||
| // Iterates through a Val Graph in topological order, calling handle on | ||
| // all Val and all Expr groups in a forward topological order. | ||
| // | ||
| // Warning: A ValGraph is not guaranteed to be a DAG. In fact, the | ||
| // AlmostExact and Permissive graphs would have cycles with a ValGroup | ||
| // and an ExprGroup. For example: | ||
| // | ||
| // [i0, 1] | ||
| // merge | ||
| // [i0*1] | ||
| // Current ValGroups: {{i0}, {1}, {i0*1}} | ||
| // map i0 and i0*1 as they effectively have the same extent | ||
| // Final ValGroups: {{i0, i0*1}, {1}} | ||
| // | ||
| // Here, the merge expr is the user of i0 and the definition of | ||
| // i0*1. Since i0 and i0*1 are mapped, the dependency chain looks | ||
| // like: | ||
| // | ||
| // {i0, i0*1} ----> {merge} ----> {i0, i0*1} | ||
| // use def | ||
| // | ||
| // These ExprGroups are called trivial ExprGroups (see also | ||
| // ValGraph::isTrivialExprGroup). | ||
| // | ||
| // Strictly speaking, these cycles mean there's no valid topological | ||
| // order anymore. In our use cases for IdModel, however, it's likely | ||
| // sufficient to return an ordering such as: | ||
| // | ||
| // {i0, i0*1} -> {merge} | ||
| // | ||
| // I.e., we visit {i0, i0*1} first even though {merge} is technically | ||
| // a definition. | ||
| // | ||
| // Another alternative may be simply giving up when such a cycle is | ||
| // detected, which may be more preferrable as it would be less | ||
| // confusing. At this moment, this visitor is only used with graphs | ||
| // with no such cycle. Should be revisited when necessary. | ||
| // | ||
| // Warning: This is not a great iterator if there's a desire to minimize paths | ||
| // traveled to simply visit all ValGroups in order. See ExprsBetween to see how | ||
| // we might minimize paths. | ||
| class ValGraphVisitor { | ||
| public: | ||
| ValGraphVisitor() = delete; | ||
|
|
||
| ValGraphVisitor& operator=(const ValGraphVisitor& other) = delete; | ||
|
|
||
| ValGraphVisitor& operator=(ValGraphVisitor&& other) = delete; | ||
|
|
||
| virtual ~ValGraphVisitor() = default; | ||
|
|
||
| protected: | ||
| ValGraphVisitor(const ValGraph& val_graph) : val_graph_(val_graph) {} | ||
|
|
||
| ValGraphVisitor(const ValGraphVisitor& other) = default; | ||
|
|
||
| ValGraphVisitor(ValGraphVisitor&& other) = default; | ||
|
|
||
| virtual void handle(const ValGroup& val_group) = 0; | ||
| virtual void handle(const ExprGroup& expr_group) = 0; | ||
|
|
||
| void traverse(); | ||
|
|
||
| const ValGraph& graph() { | ||
| return val_graph_; | ||
| }; | ||
|
|
||
| private: | ||
| const ValGraph& val_graph_; | ||
| }; | ||
|
|
||
| // Statement sorting based on ValGraphVisitor, see warnings to ValGraph Visitor. | ||
| class ValGraphStmtSort : public ValGraphVisitor { | ||
| public: | ||
| ValGraphStmtSort(const ValGraph& val_graph) : ValGraphVisitor(val_graph) { | ||
| ValGraphVisitor::traverse(); | ||
| } | ||
|
|
||
| // Return non-reference so that code like below can work | ||
| // for (auto expr_group: ValGraphStmtSort(graph).exprs()) | ||
| ExprGroups exprs() const { | ||
| return sorted_exprs_; | ||
| } | ||
|
|
||
| ValGroups vals() const { | ||
| return sorted_vals_; | ||
| } | ||
|
|
||
| ~ValGraphStmtSort() override = default; | ||
|
|
||
| protected: | ||
| using ValGraphVisitor::handle; | ||
|
|
||
| void handle(const ValGroup& val_group) override { | ||
| sorted_vals_.pushBack(val_group); | ||
| } | ||
|
|
||
| void handle(const ExprGroup& expr_group) override { | ||
| sorted_exprs_.pushBack(expr_group); | ||
| } | ||
|
|
||
| ExprGroups sorted_exprs_; | ||
| ValGroups sorted_vals_; | ||
| }; | ||
|
|
||
| } // namespace nvfuser |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.