Fix finding right non-divisible splits to predicate#2712
Conversation
|
!build |
There was a problem hiding this comment.
A lot of changes in this file are because I realized IdModel is also required to generate correct reference results. The main change is the newly added test.
| for (const PredicateDomainInfo& pred_info : | ||
| non_divisible_split_predicates) { | ||
| IterDomain* non_divisible_domain = pred_info.id; | ||
| for (const auto& [eg, direction] : index_info.traversal_path) { |
There was a problem hiding this comment.
pred_info has all the non-divisible splits of the consumer tensor, but they may not be the ones that should be added here since the actual indexing path may not include them. See the added test. The broadcast tensor, tv2, has a non-divisible split, but the real domain that needs to be predicated is the one showing up in its indexing path, which is the split in the tv3 tensor.
tests/cpp/test_indexing.cpp
Outdated
| non_divisible_domain_stop_idx, | ||
| non_divisible_split_to_predicate->in()->extent())); | ||
|
|
||
| // TODO: Consolidating unswitch predicates isn't |
There was a problem hiding this comment.
This is another oversight I've just realized. Will address in a follow-up PR.
|
Didn't realize the CI failures. They seem to be due to the consolidation problem. Moved this PR back to draft. |
The existing logic to merge redundant unswitch conditions doesn't work with the new indexer, resulting in duplicated conditions appearing in the same unswitch predicate.
967685f to
7a62447
Compare
|
!build |
| std::back_inserter(all_parallelized_consumer_loop_ids), | ||
| [](IterDomain* x) { return isParallelTypeThread(x->getParallelType()); }); | ||
|
|
||
| // Identify which parallel type is used for which loop domain for |
There was a problem hiding this comment.
Added this fix. UnswitchPredicateKey was mainly designed for supporting shift predicates. I'll clean do some cleanup later.
Added a fix for |
| // Here, the last two predicates are redundant since both of them | ||
| // guard the index with respect to the domain of extent 8, which is | ||
| // redundant. This is a bit annonying but should have no actual | ||
| // impact as the redundancy should be removed by the expression | ||
| // simplifier. |
There was a problem hiding this comment.
I just wanted to see how this works for the given example. I believe we start with these two predicates:
for i0 : 2
for i1 : ceilDiv(8, 3)
for i2 : 3
if (i2 + i1 * 3 < 8 && i2 + i1 * 3 + i0 * (ceilDiv(8, 3) * 3) < tv0.logical_shape[0]) {
...
}
If so then I can see how the second condition can be simplified to i2 + i1 * 3 + i0 * 9 < 16 but using i0 < 2 and all loop variables non-negative this becomes i2 + i1 * 3 < 7, and that implies the other condition so we can discard that other condition. Cool.
There was a problem hiding this comment.
For t0, since its logical domain is the one before the reshape, we would generate:
i2 + i1 * 3 + i0 * (ceilDiv(8, 3) * 3) < tv0.logical_shape[0]
and because of the ceilDiv, we would also generate:
i2 + i1 * 3 < 8
The latter is the non-divisible predicate for t0.
For t1, since its logical domain is [2, 8], we would generate:
i0 < 2
i2 + i1 * 3 < 8
And due to the ceilDiv, we would also generate:
i2 + i1 * 3 < 8
This is also interesting because reshape is always divisible, we can predicate [2, 8] for t0 too. That would probably help index hoisting.
There was a problem hiding this comment.
Thanks for the explanation
Co-authored-by: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com>
Co-authored-by: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com>
Co-authored-by: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com>
Co-authored-by: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com>
This is a bug fix for #2691, which isn't sufficient when broadcast tensors are involved.