Visit all Expr outputs during concretization#591
Conversation
This does not test execution, which is hitting a separate issue.
It would be better to address this in StmtSort/IterVisitor
csrc/dynamic_transform.cpp
Outdated
| if (tv->definition()) { | ||
| for (auto outp : | ||
| ir_utils::filterByType<TensorView>(tv->definition()->outputs())) { | ||
| if (visited_tvs.find(outp) == visited_tvs.end()) { | ||
| mutate(outp); | ||
| visited_tvs.insert(tv); | ||
| } | ||
| } | ||
| } else { | ||
| if (visited_tvs.find(tv) == visited_tvs.end()) { | ||
| mutate(tv); | ||
| visited_tvs.insert(tv); | ||
| } | ||
| } |
There was a problem hiding this comment.
This is awkward and inefficient, and this functionally should probably be implemented as an option to StmtSort::getStmts instead.
There was a problem hiding this comment.
Pushed an alternative that adds a traverse_siblings option for IterVisitor and StmtSort.
csrc/scheduler/utils.cpp
Outdated
|
|
||
| auto replay_exprs = | ||
| StmtSort::getExprs(tv->fusion(), {reference_id}, false, false); | ||
| StmtSort::getExprs(tv->fusion(), {reference_id}, false, false, false); |
There was a problem hiding this comment.
This was hitting the wrong overload and trying to convert {reference_id} to bool.
There was a problem hiding this comment.
Which one are you referring to?
There was a problem hiding this comment.
Oh, I realized there's a new overload in this PR
There was a problem hiding this comment.
This makes me concerned as it's easy to make this mistake. How about changing the name to avoid the overload? Maybe getExprs and getExprsTo?
There was a problem hiding this comment.
Great idea. I'll rename it.
|
!build |
| if (traverse_siblings) { | ||
| // Add unvisited siblings to next_stmts | ||
| std::vector<Statement*> unvisited_sibs; | ||
| for (auto next_val : ir_utils::filterByType<Val>(next_stmts)) { | ||
| for (auto sib : ir_utils::siblingValsOf(next_val)) { | ||
| if (visited.find(sib) == visited.end()) { | ||
| // Push to separate vector so that we don't modify next_stmts | ||
| // while looping | ||
| unvisited_sibs.push_back(sib); | ||
| } | ||
| } | ||
| } | ||
| next_stmts.insert( | ||
| next_stmts.end(), unvisited_sibs.begin(), unvisited_sibs.end()); | ||
| } |
There was a problem hiding this comment.
A couple of questions.
- Do we want to visit siblings even when they are in
visitediftraverse_all_pathsis true? - More commonly, though, since when we are visiting an input first, if that has a sibling, it would be most likely not visited yet, no matter if it's an orphaned sibling or not. Wouldn't this change the traversal ordering then? It would still be a topological order, but I'd be surprised if
traverse_siblingswould change the traversal order of a fusion with no orphaned siblings. - Can you please add unit tests just for this traversal with siblings?
There was a problem hiding this comment.
- Yeah, we should add them even if visited here I believe.
- You're right that this flag changes the ordering. An alternative that would not change the ordering would be to add siblings to a vector called
maybe_orphaned_siblingshere, then after thewhile (!stmt_stack.empty()) {}loop we would loop overmaybe_orphaned_siblingsand call handle on those that are not found invisited. This would use some extra space to hold the vector of siblings, but would perform the same number of searches invisited. - Yes. I will add some tests for the generic traversal.
There was a problem hiding this comment.
Re 2: The alternative sounds simpler and better to me. The extra space would be negligible.
There was a problem hiding this comment.
Pushed changes that address 1 and 2 above. Will add tests soon
There was a problem hiding this comment.
Pushed tests, and fixed a bug causing siblings of outputs to not be traversed.
This prevents accidental use of the wrong overload, wherein a vector of Vals might be converted to bool.
This is still a topological ordering since these siblings have no active uses (by definition), and this method prevents changing the traversal ordering by changing this flag when there are no orphaned siblings in the Fusion.
|
!build |
| stmts = StmtSort::getStmtsTo( | ||
| &fusion, | ||
| {wf.n}, | ||
| /*traverse_all_paths*/ false, | ||
| /*traverse_attributes*/ false, | ||
| /*traverse_siblings*/ true); |
There was a problem hiding this comment.
This test exposed a bug where siblings of outputs in to were not being traversed. Fixed now.
naoyam
left a comment
There was a problem hiding this comment.
LGTM. Thanks for the fix and overall improvement.
This change ensures that concretization visits all outputs to active expressions. Even though unused outputs do not directly affect the result of the Fusion, they still must not have Symbolic IterDomains during scheduling. Note that this pre-empts the more complicated #420.
This fixes the immediate cause of #418. However, that test still fails due to another issue where scalars get lost during segmentation. The test included here only tests that concretization works properly; scheduling and execution will be tested in another PR.