Fix unswitch predicate with merge inner path propagation #687
Thanks for the quick fix and the detailed description. Even a monkey head like mine seems to see what's really going wrong with the codegen. One nitpick on the commit comment:
I think you mean [3, 16], since those are the sizes used in the example below?
Thanks. Fixed.
    __LINE__,
    __FILE__);
int64_t hidden_size = 1;
index_parameters.unswitched_domains.insert(
    GpuLower::current()->caMap()->getConcreteMappedID(
        loop_id, IdMappingMode::EXACT));
Should we just do this inside the else {} clause above right after the loop_to_ind_map[loop] = extent - 1? This way, can we remove trackUnswitchedDomain?
Not necessarily. We don't want to track vectorized domains or domains mapped with vectorized domains. They should still need to be filtered out even if this code is moved there.
Also, tracking may be necessary even if the loop is parallelized when the domain is extended for halo. I haven't tested such patterns, but placing this tracking here seems to make the most sense to me.
On a slightly related note, I'm not sure why I did it this way:
https://github.com/NVIDIA/Fuser/pull/687/files/88183264b5ad608e6efd3af8adb22573eb269937#diff-2d7e9c403d33ef6dd38c80bfa051f541416532aabe7679112e7d3ca94c51f6e9R443-R444
It seems it should be loop_to_ind_map[loop] = loop->stop() - 1.
I think we don't have to handle IDs mapped to vectorization explicitly, because they will be automatically handled by the inner_extent % total_extent == 0 check in isModuloInvalidUnswitchedIndex. Still, let's leave this PR as is and handle them explicitly, because why not?
I remember the check was necessary to avoid some predicate changes. Let me check again.
I think it was just because inner_extent in isModuloInvalidUnswitchedIndex may not be statically known, so the statement "it will be automatically handled by the inner_extent % total_extent == 0 check in isModuloInvalidUnswitchedIndex" may not hold.
for (auto it = unswitched_domain_list.begin();
     it != unswitched_domain_list.end() - 1;
     ++it) {
Am I understanding correctly that the last item of the deque is always assigned index extent - 1?
I think that should be true for most of the cases, but think about:
[i0, i1]
-> [i0, i1o, i1i] // split i1
-> [i0, i1i, i1o] // reorder
Then unswitch only the last domain. The last item in the deque for i1 would be i1i, but since it's not unswitched, its index would not be extent - 1.
I haven't thought about patterns like this. Wonder if this would break anything...
For this case, shouldn't unswitched_domain_map_ be initialized as
{
  i1o: {{i1o}}
}
and updateUnswitchedDomains will update it as:
{
  i1o: {{i1o}},
  i1: {{i1i, i1o}}
}
?
Oh, sorry, you're right. I think this is correct:
the last item of the deque is always assigned index extent - 1
if (simplifyExpr(IrBuilder::modExpr(inner_extent, total_extent))
        ->isZero()) {
  continue;
}
Looks like there are two cases known to be safe. The first one corresponds to distributeDivisibleDivMod in the expr simplifier, which says (i * 128 + j) % 4 = j % 4. The second one corresponds to distributeGcdRemainderDivMod in the expr simplifier, which says that if j < 128, then (i * 128 + j) % 1024 = (i * 128) % 1024 + j. For the first case, the assigned index for unswitched_domain_list.back() is i, and for the second case, the assigned index for unswitched_domain_list.back() is j. Am I understanding correctly?
I think that's correct.
Could you add this note to the code? At least for me, if you say "the distributeDivisibleDivMod pattern", I would immediately get what it means, whereas any other explanation I would have to translate before I could understand it.
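The two identities quoted in this thread can be checked by brute force. The sketch below is only an illustration of the arithmetic with the constants from the comment (128, 4, 1024); the function names are hypothetical and it does not reproduce how distributeDivisibleDivMod or distributeGcdRemainderDivMod are implemented in the expr simplifier.

```cpp
#include <cassert>
#include <cstdint>

// Checks (i * 128 + j) % 4 == j % 4 for all 0 <= i < num_i, 0 <= j < num_j.
// The identity relies on 128 being a multiple of 4.
bool divisibleIdentityHolds(int64_t num_i, int64_t num_j) {
  for (int64_t i = 0; i < num_i; ++i) {
    for (int64_t j = 0; j < num_j; ++j) {
      if ((i * 128 + j) % 4 != j % 4) {
        return false;
      }
    }
  }
  return true;
}

// Checks (i * 128 + j) % 1024 == (i * 128) % 1024 + j
// for all 0 <= i < num_i, 0 <= j < j_bound.
// The identity relies on j staying below 128, so that adding j to a
// multiple of 128 never crosses a multiple of 1024.
bool gcdRemainderIdentityHolds(int64_t num_i, int64_t j_bound) {
  for (int64_t i = 0; i < num_i; ++i) {
    for (int64_t j = 0; j < j_bound; ++j) {
      if ((i * 128 + j) % 1024 != (i * 128) % 1024 + j) {
        return false;
      }
    }
  }
  return true;
}

// Hypothetical helper that asserts the facts discussed in the thread.
void checkQuotedIdentities() {
  assert(divisibleIdentityHolds(64, 512));
  assert(gcdRemainderIdentityHolds(64, 128));
  // With j == 128 allowed, the second identity breaks (e.g., i == 7).
  assert(!gcdRemainderIdentityHolds(64, 129));
}
```

The last assertion shows why the j < 128 condition matters: once j can reach 128, the sum can wrap around 1024 and the right-hand side no longer matches.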
…through merge-inner or split-outer domains (#2689)
(Stacked on top of #2677)
The original issue is #681. It was addressed in #687. This PR is NOT as comprehensive as #687, but my gut feeling is that this should be good enough, in particular since contig indexing would avoid backward traversals through merge in many cases. I'll do a final, more comprehensive comparison with the legacy indexing once contig indexing is done.
Since the original PR and issue were reviewed by @zasdfgbnm, could you please review this too?
Fixes #681 and #667
The issue of #681 is due to a bug in generating predicates for unswitched (or unrolled) domains. What unswitch does is, for example, when we have a kernel like:
where PRED is a function to produce a predicate based on loop index i. With unswitch, this kernel would look like:
That is, if the maximum of PRED(i) with 0 <= i < N is less than the extent, we can safely eliminate the predicate from the loop body, presumably increasing the performance. An important assumption here is that the maximum of PRED(i) is PRED(N-1).
The issue of #681 is that the assumption is not always valid when there's a merge operation between predicated domains and leaf domains, because of the modulo used to propagate predicate indices through the merge inner path. For example, think about a 2D tensor of dimensions [3, 16].
And assume that only the innermost two leaf domains are unswitched. The initial index of the unswitched domains would be 2 and 4. The generated code would look like (for simplicity contig predication is ignored):
Notice that the initial index of 2 for the second leaf domain is not guaranteed to send the maximum index through the merge inner path. For example, when i == 1, the value of the index math, ((i * 3 + 2) % 4), is 1, which means the second predicate is true. If, instead of 2, an initial index of 0 were used for the second innermost domain, the value of the index would have been 3, so the second predicate is false. This means that picking 2 as the initial index of the second leaf domain does not generate a sufficient predicate for the second root domain. In fact, in this case, there's no single initial index value that can always generate sufficient predicates, so we would need to replace ((i * 3 + 2) % 4) with 3 when back-traversing to the merge inner path, which is done here: https://github.com/NVIDIA/Fuser/pull/687/files#diff-625d71418720e0d8f49be94352457734eea3d6b372a44e53b1afd4484aad3d20R655-R659
See the UnswitchPredicateIssueRepro681 test for a concrete example. The issue was originally discovered in a failing fusion as shown in the UnswitchPredicateIssue667 test, which was created by @jjsjann123 in #667.
Obviously, while this correctness issue needs to be fixed, doing the above replacement could negate the benefit of unswitching. For example, in the above case, replacing ((i * 3 + 2) % 4) with 3 means that the second predicate is always false, so there's no opportunity to take advantage of the unswitched path, and in fact the increased code complexity is likely to negatively impact the overall performance.
Fortunately, this change has no impact in many cases, as predication is done at merge output domains as long as the domains are contiguously merged. Since root domains are considered contiguous when generating predicates, as long as merge is done only with adjacent domains, all predicate expressions should just consist of predicates generated based on post-merge domains.
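The index arithmetic of the [3, 16] example above can be checked with a small brute-force sketch. It assumes the unswitched leaf domain has extent 3 (inferred from its initial index of 2) and treats the outer loop range as a free parameter; the function names are hypothetical and are not part of nvFuser.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Index sent through the merge inner path in the example:
// i is the outer loop index, j is the index of the unswitched leaf domain.
constexpr int64_t innerIndex(int64_t i, int64_t j) {
  return (i * 3 + j) % 4;
}

// Largest inner index over all j in [0, 3) for a given outer index i.
// Extent 3 is an assumption inferred from the initial index of 2.
int64_t maxInnerIndex(int64_t i) {
  int64_t result = 0;
  for (int64_t j = 0; j < 3; ++j) {
    result = std::max(result, innerIndex(i, j));
  }
  return result;
}

// True if fixing the leaf index to j0 yields the maximum inner index
// for every outer index i in [0, num_outer).
bool initialIndexAlwaysMaximal(int64_t j0, int64_t num_outer) {
  for (int64_t i = 0; i < num_outer; ++i) {
    if (innerIndex(i, j0) != maxInnerIndex(i)) {
      return false;
    }
  }
  return true;
}

// Hypothetical helper asserting the values quoted in the description.
void checkMergeInnerExample() {
  assert(innerIndex(1, 2) == 1);  // initial index 2 at i == 1
  assert(innerIndex(1, 0) == 3);  // initial index 0 at i == 1
  // No single initial index is maximal for both i == 0 and i == 1.
  for (int64_t j0 = 0; j0 < 3; ++j0) {
    assert(!initialIndexAlwaysMaximal(j0, 2));
  }
}
```

Because no choice of j0 is maximal for every i, the only safe value to propagate is the largest value the modulo can take, 3, which matches the replacement described above.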
There are, however, several cases where we do non-contiguous merges. Specifically, in the transpose scheduler, a common transformation pattern is:
Here, the merge of 32, 32 -> 32 * 32 is not contiguous, thus predicates at the root domains, i.e., i1 and i2, would be used, meaning predicate indexing would traverse through a modulo operation. Suppose the innermost three domains are unswitched. The unswitch predicate for the i2 root domain would look like: i % (i2 / 32) * 32 + 31 < i2, where i is the loop index of the outermost loop. Most commonly, the two innermost domains would be parallelized and vectorized as:
Thus, the unswitch predicate with the current main branch would be: i % (i2 / 32) * 32 + ((128 + threadIdx.x) * 4 + 3) % 32 < i2. With the above replacement, this predicate would be changed to i % (i2 / 32) * 32 + 31 < i2, which is much more restrictive than the current predicate, meaning the unswitched path is less likely to be usable.
However, this replacement is not actually necessary in this case. Out of the three unswitched domains, only the outermost domain of 32 * 32 / 4 / 128 is actually unswitched since the other two leaf domains have only one valid index value, i.e., threadIdx.x and 3, respectively. The true unswitched domain has two options for the initial index: 0 or 1. However, the choice doesn't matter since the index is always multiplied by 128 * 4, meaning its contribution is always zero modulo 32, so the unswitched domain has no contribution to the final root predicate. As a result, the replacement for the merge inner path can be omitted for this case.
More specifically, see trackUnswitchedDomain in csrc/device_lower/analysis/index_compute.cpp for which leaf domains to consider and csrc/index_compute.h,cpp for how they are used to extend the predicate indexing. I tried to make this analysis as precise as possible so that the replacement is only done when absolutely necessary. See isModuloInvalidUnswitchedIndex for when we can declare the normal modulo-based propagation is safe for unswitch.
With the analysis of isModuloInvalidUnswitchedIndex, as far as I can see, no existing C++ tests and benchmarks are affected by this PR, meaning there should be no kernel performance change (although the lowering overhead is increased).
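The claim that the truly unswitched domain does not contribute to the i2 predicate in the transpose case can be checked numerically. The sketch below assumes the inner term has the form ((k * 128 + threadIdx.x) * 4 + 3) % 32, with k in {0, 1} standing for the unswitched index and tid standing for threadIdx.x in [0, 128); the names are hypothetical and this is not the analysis done by isModuloInvalidUnswitchedIndex itself.

```cpp
#include <cassert>
#include <cstdint>

// Inner term of the i2 predicate in the transpose example.
// k: index of the truly unswitched domain (0 or 1)
// tid: stand-in for threadIdx.x, v: index within the vectorized domain
constexpr int64_t innerTerm(int64_t k, int64_t tid, int64_t v) {
  return ((k * 128 + tid) * 4 + v) % 32;
}

// True if the inner term is the same for k == 0 and k == 1 for every
// thread index, i.e., the unswitched index has no effect modulo 32.
bool unswitchedIndexIsIrrelevant() {
  for (int64_t tid = 0; tid < 128; ++tid) {
    if (innerTerm(0, tid, 3) != innerTerm(1, tid, 3)) {
      return false;
    }
  }
  return true;
}

// Hypothetical helper asserting the property described in the text.
void checkTransposeExample() {
  // k is multiplied by 128 * 4 == 512, which is a multiple of 32.
  static_assert((128 * 4) % 32 == 0, "512 must be divisible by 32");
  assert(unswitchedIndexIsIrrelevant());
  assert(innerTerm(1, 127, 3) == 31);
}
```

Since either choice of k gives the same value, the ordinary modulo-based propagation is already sufficient here and the more restrictive constant 31 is not needed.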