Skip to content

Fix unswitch predicate with merge inner path propagation#687

Merged
naoyam merged 33 commits intomainfrom
fix_unswitch_pred_issue_681
Aug 15, 2023
Merged

Fix unswitch predicate with merge inner path propagation#687
naoyam merged 33 commits intomainfrom
fix_unswitch_pred_issue_681

Conversation

@naoyam
Copy link
Collaborator

@naoyam naoyam commented Aug 4, 2023

Fixes #681 and #667

The issue of #681 is due to a bug in generating predicates for unswitched (or unrolled) domains. What unswitch does is, for example, when we have a kernel like:

for (i = 0; i < N; ++i) {
   if (PRED(i) < ExtentOfSomeRootDomain && ...) {
      t1[i] = t0[i];
   }
}

where PRED is a function to produce a predicate based on loop index i. With unswitch, this kernel would look like:

if (PRED(N-1) < ExtentOfSomeRootDomain && ...) {
  for (i = 0; i < N; ++i) {
      t1[i] = t0[i];
  }
} else {
  for (i = 0; i < N; ++i) {
     if (PRED(i) < ExtentOfSomeRootDomain && ...) {
        t1[i] = t0[i];
     }
  }
}

That is, if the maximum of PRED(i) with 0 <= i < N is less than the extent, we can safely eliminate the predicate from the loop body, presumably increasing the performance. An important assumption here is that the maximum of PRED(i) is PRED(N-1).

The issue of #681 is that the assumption is not always valid when there's a merge operation between predicated domains to leaf domains because of the modulo used to propagate predicate indices through the merge inner path. For example, think about a 2D tensor of dimensions [3, 16].

[3, 16]
// split 16 by 5 
-> [3, 4, 5]
// merge 3 and 4
-> [3 * 4, 5]
// split 12 by 3
-> [4, 3, 5]

And assume that only the innermost two leaf domains are unswitched. The initial index of the unswitched domains would be 2 and 4. The generated code would look like (for simplicity contig predication is ignored):

for (i = 0; i < 4 ++i) {
  if ((i * 3 + 2) / 4 < 3 && ((i * 3 + 2) % 4) * 5 + 4 < 16) {
    for (j = 0; j < 3; ++j) {
      for (k = 0; k < 5; ++k) {
        some_math_op_using_the_tensor;
      }
    }
  } else {
    for (j = 0; j < 3; ++j) {
      for (k = 0; k < 5; ++k) {
        if ((i * 3 + j) / 4 < 3 && ((i * 3 + j) % 4) * 5 + k < 16) {
          some_math_op_using_the_tensor;
        }
      }
    }
  }
}    

Notice that the initial index of 2 for the second leaf domain is not guaranteed to send the maximum index through the merge inner path. For example, when i == 1, the value of the index math, ((i * 3 + 2) % 4), is 1, which means the second predicate is true. If, instead of 2, an initial index of 0 were used for the second innermost domain, the value of the index would have been 3, so the second predicate is false. This means that picking 2 as the initial index of the second leaf domain does not generate a sufficient predicate for the second root domain. In fact, in this case, there's no single initial index value that can always generate sufficient predicates, so we would need to replace ((i * 3 + 2) % 4) with 3 when back-traversing to the merge inner path, which is done here: https://github.com/NVIDIA/Fuser/pull/687/files#diff-625d71418720e0d8f49be94352457734eea3d6b372a44e53b1afd4484aad3d20R655-R659

See the UnswitchPredicateIssueRepro681 test for a concrete example. The issue was originally discovered in a failing fusion as shown in the UnswitchPredicateIssue667, which was created by @jjsjann123 in #667.

Obviously, while this correctness issue needs to be fixed, doing the above replacement could negate the benefit of unswitching. For example, in the above case, replacing ((i * 3 + 2) % 4) with 3 means that the second predicate is always false, so there's no opportunity of taking advantage of the unswitched path, and in fact the increased code complexity is likely to negatively impact the overall performance.

Fortunately, this change has no impact in many cases as predication is done at merge output domains as long as contiguously merged. Since root domains are considered contiguous when generating predicates, as long as merge is done only with adjacent domains, all predicate expressions should just consist of predicates generated based on post-merge domains.

There are however several cases where we do non-contiguous merge. Specifically, in the transpose scheduler, a common transformation pattern is:

[i1, i2]
-> [i1 / 32, 32, i2 / 32, 32]
-> [i1 / 32 * i2 / 32, 32 * 32]
-> [i1 / 32 * i2 / 32, 32 * 32 / 4, 4]
-> [i1 / 32 * i2 / 32, 32 * 32 / 4 / 128, 128, 4]

Here, the merge of 32, 32 -> 32 * 32 is not contiguous, thus predicates at the root domains, i.e., i1 and i2, would be used, meaning predicate indexing would traverse through a modulo operation. Suppose the innermost three domains are unswitched, the unswitch predicate for the i2 root domain would look like: i % (i2 / 32) * 32 + 31 < i2, where i is the loop index of the outermost loop. Most commonly, the two innermost domains would be parallelized and vectorized as:

[i1 / 32 * i2 / 32, 32 * 32 / 4 / 128, threadIdx.x(128), vec(4)]

Thus, the unswitch predicate with the current main branch would be: i % (i2 / 32) * 32 + ((128 + threadIdx.x) * 4 + 3) % 32 < i2. With the above replacement, this predicate would be changed to i % (i2 / 32) * 32 + 31 < i2, which is much more restrictive than the current predicate, meaning less likely to be able to use the unswitched path.

However, this replacement is not actually necessary in this case. Out of the three unswitched domains, only the outermost domain of 32 * 32 / 4 / 128 is actually unswitched since the other two leaf domains have only one valid index value, i.e., threadIdx.x and 3, respectively. The true unswitched domain has two options for the initial index: 0 or 1, however, it actually doesn't matter since it's always multiplied by 128 * 4, meaning the modulo by 32 is always zero, so the unswitched domain has no contribution to the final root predicate. As a result, the replacement for the merge inner path can be omitted for this case.

More specifically, see trackUnswitchedDomain in csrc/device_lower/analysis/index_compute.cpp for which leaf domains to consider and csrc/index_compute.h,cpp for how they are used to extend the predicate indexing. I tried to make this analysis as precise as possible so that the replacement is only done when absolutely necessary. See isModuloInvalidUnswitchedIndex for when we can declare the normal modulo-based propagation is safe for unswitch.

With the analysis of isModuloInvalidUnswitchedIndex, as far as I can see, no existing C++ tests and benchmarks are affected by this PR, meaning there should be no kernel performance change (although the lowering overhead is increased).

@naoyam
Copy link
Collaborator Author

naoyam commented Aug 8, 2023

!build

@naoyam
Copy link
Collaborator Author

naoyam commented Aug 9, 2023

!build

@naoyam naoyam changed the title WIP: Fix unswitch predicate Fix unswitch predicate with merge inner path propagation Aug 9, 2023
@naoyam naoyam marked this pull request as ready for review August 9, 2023 05:42
@naoyam naoyam requested a review from zasdfgbnm August 9, 2023 05:44
@jjsjann123
Copy link
Collaborator

Thanks for the quick fix and the detailed description. Even a monkey head like mine seems to see what's really going wrong with the codegen.

One nitpick on the commit comment

For example, think about a 2D tensor of dimensions [3, 40].

[3, 16]
// split 16 by 5 
-> [3, 4, 5]
// merge 3 and 4
-> [3 * 4, 5]
// split 12 by 3
-> [4, 3, 5]

I think you mean [3, 16], since that's the sizes used in the example below?

@naoyam
Copy link
Collaborator Author

naoyam commented Aug 9, 2023

Thanks for the quick fix and the detailed description. Even a monkey head like mine seems to see what's really going wrong with the codegen.

One nitpick on the commit comment

For example, think about a 2D tensor of dimensions [3, 40].

[3, 16]
// split 16 by 5 
-> [3, 4, 5]
// merge 3 and 4
-> [3 * 4, 5]
// split 12 by 3
-> [4, 3, 5]

I think you mean [3, 16], since that's the sizes used in the example below?

Thanks. Fixed.

__LINE__,
__FILE__);

int64_t hidden_size = 1;
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unused variable

Comment on lines +464 to +466
index_parameters.unswitched_domains.insert(
GpuLower::current()->caMap()->getConcreteMappedID(
loop_id, IdMappingMode::EXACT));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we just do this inside the else {} clause above right after the loop_to_ind_map[loop] = extent - 1? This way, can we remove trackUnswitchedDomain?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not necessarily. We don't want to track vectorized domains or domains mapped with vectorized domains. They should still need to be filtered out even if this code is moved there.

Also, tracking may be necessary even if the loop is parallelized when the domain is extended for halo. I haven't tested such patterns, but placing this tracking here seems to make most sense to me.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On a slightly related note, I'm not sure why I did this way:
https://github.com/NVIDIA/Fuser/pull/687/files/88183264b5ad608e6efd3af8adb22573eb269937#diff-2d7e9c403d33ef6dd38c80bfa051f541416532aabe7679112e7d3ca94c51f6e9R443-R444

It seems it should be loop_to_ind_map[loop] = loop->stop() - 1

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we don't have to handle IDs mapped to vectorization explicitly, because it will be automatically handled by the inner_extent % total_extent == 0 check in isModuloInvalidUnswitchedIndex, but still, let's just leave this PR as is and handle it explicitly, because why not?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I remember the check was necessary to avoid some predicate changes. Let me check again.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it was just because inner_extent in isModuloInvalidUnswitchedIndex may not be statically known, so

because it will be automatically handled by the inner_extent % total_extent == 0 check in isModuloInvalidUnswitchedIndex

may not hold.

Comment on lines +486 to +488
for (auto it = unswitched_domain_list.begin();
it != unswitched_domain_list.end() - 1;
++it) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Am I understanding correct that the last item of the deque is always assigned index extent - 1?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that should be true for most of the cases, but think about:

[i0, i1]
-> [i0, i1o, i1i] // split i1
-> [i0, i1i, i1o] // reorder

And unswitch only the last domain. The last item for the deque of i1 would be i1i, but since it's not unswitched, the index would not be extent -1.

I haven't thought about patterns like this. Wonder if this would break anything...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this case, shouldn't unswitched_domain_map_ be initialized as

{ 
  i1o: {{i1o}}
}

and updateUnswitchedDomains will update it as:

{
  i1o: {{i1o}},
  i1: {{i1i, i1o}}
}

?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, sorry, you're right. I think this is correct:

the last item of the deque is always assigned index extent - 1

if (simplifyExpr(IrBuilder::modExpr(inner_extent, total_extent))
->isZero()) {
continue;
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like there are two cases known to be safe, the first one correspond to distributeDivisibleDivMod in expr simplifier, which says (i * 128 + j) % 4 = j % 4, the second one correspond to distributeGcdRemainderDivMod in expr simplifier, which says if j < 128, then (i * 128 + j) % 1024 = (i * 128) % 1024 + j. For the first case, the assigned index for unswitched_domain_list.back() is i, and for the second case, the assigned index for unswitched_domain_list.back() is j. Am I understanding correct?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that's correct.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add this note to the code? At least for me, if you say "the distributeDivisibleDivMod pattern", I would immediately get what it means, and any other way of explaining I would have to translate it in order to understand.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

Copy link
Collaborator

@zasdfgbnm zasdfgbnm left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the fix!

@naoyam naoyam merged commit 1c0cdb2 into main Aug 15, 2023
@naoyam naoyam deleted the fix_unswitch_pred_issue_681 branch August 15, 2023 06:19
naoyam added a commit that referenced this pull request Jul 27, 2024
…through merge-inner or split-outer domains (#2689)

(Stacked on top of #2677)

The original issue is #681. It was addressed in #687.

This PR is NOT as comprehensive as #687, but my gut feeling is that this
should be good enough, in particular since contig indexing would avoid
backward traversals through merge in many cases.

I'll do final more comprehensive comparison with the legacy indexing
once contig indexing is done.

Since the original PR and issue were reviewed by @zasdfgbnm, could you
please review this too?
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Wrong unswitch predicate

3 participants