Skip to content

Fix finding right non-divisible splits to predicate#2712

Merged
naoyam merged 8 commits intomainfrom
idmodel_indexing_non_divisible_fix
Jul 31, 2024
Merged

Fix finding right non-divisible splits to predicate#2712
naoyam merged 8 commits intomainfrom
idmodel_indexing_non_divisible_fix

Conversation

@naoyam
Copy link
Collaborator

@naoyam naoyam commented Jul 29, 2024

This is a bug fix for #2691, which isn't sufficient when broadcast tensors are involved.

@naoyam
Copy link
Collaborator Author

naoyam commented Jul 29, 2024

!build

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A lot of changes in this file are because I realized IdModel is also required to generate correct reference results. The main change is the newly added test.

for (const PredicateDomainInfo& pred_info :
non_divisible_split_predicates) {
IterDomain* non_divisible_domain = pred_info.id;
for (const auto& [eg, direction] : index_info.traversal_path) {
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

pred_info has all the non-divisible splits of the consumer tensor, but they may not be the ones that should be added here since the actual indexing path may not include them. See the added test. The broadcast tensor, tv2, has a non-divisible split, but the real domain that needs to be predicated is the one showing up in its indexing path, which is the split in the tv3 tensor.

non_divisible_domain_stop_idx,
non_divisible_split_to_predicate->in()->extent()));

// TODO: Consolidating unswitch predicates isn't
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is another oversight I've just realized. Will address in a follow-up PR.

@naoyam naoyam marked this pull request as ready for review July 29, 2024 22:07
@naoyam naoyam requested a review from jacobhinkle July 29, 2024 22:07
@naoyam naoyam marked this pull request as draft July 30, 2024 02:33
@naoyam
Copy link
Collaborator Author

naoyam commented Jul 30, 2024

Didn't realize the CI failures. They seem to be due to the consolidation problem. Moved this PR back to draft.

naoyam added 2 commits July 29, 2024 21:05
The existing logic to merge redundant unswitch conditions doesn't work
with the new indexer, resulting in duplicated conditions appearing in
the same unswitch predicate.
@naoyam naoyam force-pushed the idmodel_indexing_non_divisible_fix branch from 967685f to 7a62447 Compare July 30, 2024 15:23
@naoyam
Copy link
Collaborator Author

naoyam commented Jul 30, 2024

!build

std::back_inserter(all_parallelized_consumer_loop_ids),
[](IterDomain* x) { return isParallelTypeThread(x->getParallelType()); });

// Identify which parallel type is used for which loop domain for
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added this fix. UnswitchPredicateKey was mainly designed for supporting shift predicates. I'll clean do some cleanup later.

@naoyam naoyam added the idmodel label Jul 30, 2024
@naoyam naoyam marked this pull request as ready for review July 30, 2024 17:13
@naoyam
Copy link
Collaborator Author

naoyam commented Jul 30, 2024

Didn't realize the CI failures. They seem to be due to the consolidation problem. Moved this PR back to draft.

Added a fix for UnswitchPredicateKey. It's a bit ugly WAR as I don't want to break the existing usage of the class.

Copy link
Collaborator

@jacobhinkle jacobhinkle left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Comment on lines +87 to +91
// Here, the last two predicates are redundant since both of them
// guard the index with respect to the domain of extent 8, which is
// redundant. This is a bit annonying but should have no actual
// impact as the redundancy should be removed by the expression
// simplifier.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just wanted to see how this works for the given example. I believe we start with these two predicates:

for i0 : 2
  for i1 : ceilDiv(8, 3)
    for i2 : 3
       if (i2 + i1 * 3 < 8 && i2 + i1 * 3 + i0 * (ceilDiv(8, 3) * 3) < tv0.logical_shape[0]) {
         ...
       }

If so then I can see how the second condition can be simplified to i2 + i1 * 3 + i0 * 9 < 16 but using i0 < 2 and all loop variables non-negative this becomes i2 + i1 * 3 < 7, and that implies the other condition so we can discard that other condition. Cool.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For t0, since its logical domain is the one before the reshape, we would generate:

i2 + i1 * 3 + i0 * (ceilDiv(8, 3) * 3) < tv0.logical_shape[0]

and because of the ceilDiv, we would also generate:

i2 + i1 * 3 < 8

The latter is the non-divisible predicate for t0.

For t1, since its logical domain is [2, 8], we would generate:

i0 < 2
i2 + i1 * 3 < 8

And due to the ceilDiv, we would also generate:

i2 + i1 * 3 < 8

This is also interesting because reshape is always divisible, we can predicate [2, 8] for t0 too. That would probably help index hoisting.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the explanation

naoyam and others added 4 commits July 31, 2024 12:09
Co-authored-by: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com>
Co-authored-by: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com>
Co-authored-by: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com>
Co-authored-by: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com>
@naoyam naoyam merged commit a66b528 into main Jul 31, 2024
@naoyam naoyam deleted the idmodel_indexing_non_divisible_fix branch July 31, 2024 20:44
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants