Review updated until commit c01a3b8
!test --diff |
EXPECT_EQ(tv6->getLoopDomain().back()->extent()->evaluate(), 2);
}
// The current analysis is not precise enough to pass this test
In this test, tv0 is resized in two different ways. The spanning-tree based analysis is not guaranteed to correctly identify the vectorization constraint.
The WAR when applied to this case is still too conservative.
!build |
jjsjann123
left a comment
LGTM, glad to see that we are addressing the earlier comment
// Check if vectorization is properly applied even when a resized ID
// is reachable from vectorized IDs. Pattern extracted from Litgpt
// LLama RoPE backward.
nitpick on comment.
This is the case where it's safe to skip the additional resize check. So that means the resized ID is NOT reachable from vectorized IDs.
}
max_vec_size = std::gcd(max_vec_size, inferred_val.as<int64_t>());
auto inferred_val_int = inferred_val.as<int64_t>();
if (inferred_val_int == 0) {
this is for dynamic resize extents that would be 0?
resize_in_out_groups.pushBack(graph.toGroup(resize->out()));
CanSkipResize bfs(graph, ref_groups, resize_in_out_groups, resize);
bfs.traverse();
return bfs.allToNodesVisited();
qq: here we are calling allToNodesVisited(), but the init function below has /*require_all_to_visited=*/false, so we are returning true here as long as a single node is visited in the target, which I think is the right behavior.
But the function name is somewhat confusing.
No, the traversal should continue until no further progress is made. The require_all_to_visited flag means it's considered an error if not all of the to nodes could be reached.
Here, we just want to check whether all of the to nodes are reachable. It isn't an error if they are not.
so we are returning true here as long as a single node is visited in the target, which I think is the right behavior.
sorry I got confused myself. This returns true indicating it's safe to skip the check. So allToNodesVisited is the proper name for the function.
Thanks for elaborating on this one.
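To make the distinction in this thread concrete, here is a minimal sketch of a "run to a fixed point, then ask whether every target was visited" query. It is a plain BFS over an integer adjacency list, not nvFuser's graph traversal; the names `Graph` and `allTargetsVisited` are hypothetical and exist only for illustration.

```cpp
#include <queue>
#include <unordered_map>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for a graph: node id -> directly reachable node ids.
using Graph = std::unordered_map<int, std::vector<int>>;

// Traverses from `from` until no further progress is made, then reports
// whether every node in `to` was visited. An unreachable target is not an
// error here; it only makes the query return false.
bool allTargetsVisited(
    const Graph& graph,
    const std::vector<int>& from,
    const std::vector<int>& to) {
  std::unordered_set<int> visited(from.begin(), from.end());
  std::queue<int> pending;
  for (int node : from) {
    pending.push(node);
  }
  while (!pending.empty()) {
    int current = pending.front();
    pending.pop();
    auto it = graph.find(current);
    if (it == graph.end()) {
      continue; // no outgoing edges
    }
    for (int next : it->second) {
      if (visited.insert(next).second) {
        pending.push(next); // first visit, keep exploring from here
      }
    }
  }
  for (int target : to) {
    if (visited.count(target) == 0) {
      return false;
    }
  }
  return true;
}
```

In this sketch a single visited target is not enough for a `true` result: the traversal always runs to completion and the answer depends on all targets.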
This is a follow-up to #3906, which added a WAR to #3640. While it's safe, it turned out it's just too conservative. For example, here's a concat pattern appearing in the backward of Litgpt Llama RoPE:
This is currently taken by the pointwise scheduler, which attempts to vectorize the innermost ID of the output (i.e., iS52{6144}). Since the resize ops of the three pad ops are reachable from iS52, the WAR of #3640 simply takes them into consideration by calculating gcd with the left and right expand factors. In this case, since there's an expand factor of 1, the resulting vectorization factor is also just 1, which is clearly not what we want. Here, while the resized ID itself is not vectorizable due to the expand factor of 1, all of the resized tensors have large enough inner IDs that should allow the maximum vectorization.

To make the WAR a little less conservative, this PR also checks if the constraint by a Resize expr may be missed by the vectorization analysis. In the above case, that should not happen as there's only one path through each of the resize-based tensor ops.
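As a rough illustration of why the gcd-based WAR collapses to 1, here is a minimal sketch of the arithmetic only. The helper `conservativeVecFactor` and its inputs are hypothetical, not the scheduler's actual code; it assumes the expand factors are already known positive integers.

```cpp
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical helper: folds every resize expand factor into the candidate
// vectorization factor with gcd, mirroring the conservative behavior
// described above.
int64_t conservativeVecFactor(
    int64_t max_vec_size,
    const std::vector<int64_t>& expand_factors) {
  for (int64_t factor : expand_factors) {
    max_vec_size = std::gcd(max_vec_size, factor);
  }
  return max_vec_size;
}
```

With a candidate factor of 8, expand factors of {4, 8} give 4, while any set that contains 1, such as {1, 16}, gives 1 regardless of how large the other extents are.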
This change is still not able to eliminate false positives completely. See one of the new tests that is currently disabled.
The codediff results all seem to make sense. http://nv/eFb. Previously some of the tests did not have vectorization due to the WAR, which is relaxed in this PR and allows some vectorization.