Only includes minimum logic for basic indexing. Notably, no support for broadcast in this PR.
@jacobhinkle Please review this PR, and feel free to ask any questions. As I mentioned above, this PR has lots of missing pieces, but it should work for indexing pointwise and reduction fusions like those used in the new unit tests. I'll continue preparing the next set of PRs. CC: @zasdfgbnm
    if (!ir_utils::isTvOp(expr)) {
      continue;
    }
    auto tv_output = ir_utils::getTvOutput(expr);
It's safe to look at only the first output because sibling tensors, being produced by the same expression, are required to be loop mapped with one another in all dimensions, right?
Right, but that reminds me of the SDPA issue.
    case MemoryType::Local:
      // Nothing is shared if it's Local
      return false;
    case MemoryType::Shared:
      // Only TID parallelized domains are shared if it's Shared
      return isParallelTypeThreadDim(parallel_type);
    case MemoryType::Global:
      // Only TID and BID parallelized domains are shared if it's Global
      return isParallelTypeThreadDim(parallel_type) ||
          isParallelTypeBlockDim(parallel_type);
It seems to me that this is related to isPartitionedMemory, since isSharedMemory(memory_type, parallel_type) is equivalent to (isParallelTypeThread(parallel_type) || isParallelTypeDeviceDim(parallel_type)) && !isPartitionedMemory(memory_type, parallel_type). It might be nice to write it this way, since it clearly indicates that, at least for the "parallelized" dimensions above serial or vectorized, the notion of memory partitioning is just the inverse of sharing.
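To make the suggested equivalence concrete, here is a minimal, self-contained sketch. It does not use the real nvFuser headers: the enums, the helper predicates, and in particular the body of `isPartitioned` are assumptions written for illustration (Local is taken to be partitioned across threads, blocks and devices; Shared across blocks and devices; Global across devices only). Under those assumptions, the switch-based form and the "inverse of partitioning" form agree for every combination.

```cpp
#include <cassert>

// Stand-in enums for illustration only; not the real nvFuser definitions.
enum class MemoryType { Local, Shared, Global };
enum class ParallelType { DIDx, BIDx, BIDy, BIDz, TIDx, TIDy, TIDz, Serial, Vectorize };

constexpr bool isThreadDim(ParallelType pt) {
  return pt == ParallelType::TIDx || pt == ParallelType::TIDy ||
      pt == ParallelType::TIDz;
}
constexpr bool isBlockDim(ParallelType pt) {
  return pt == ParallelType::BIDx || pt == ParallelType::BIDy ||
      pt == ParallelType::BIDz;
}
constexpr bool isDeviceDim(ParallelType pt) {
  return pt == ParallelType::DIDx;
}
// Assumed meaning of "thread" here: either a TID or a BID dimension.
constexpr bool isThread(ParallelType pt) {
  return isThreadDim(pt) || isBlockDim(pt);
}

// Direct form, mirroring the switch statement under review.
constexpr bool isSharedDirect(MemoryType mt, ParallelType pt) {
  switch (mt) {
    case MemoryType::Local:
      return false;
    case MemoryType::Shared:
      return isThreadDim(pt);
    case MemoryType::Global:
      return isThreadDim(pt) || isBlockDim(pt);
  }
  return false;
}

// Hypothetical partitioning predicate (an assumption, see the lead-in).
constexpr bool isPartitioned(MemoryType mt, ParallelType pt) {
  switch (mt) {
    case MemoryType::Local:
      return isThread(pt) || isDeviceDim(pt);
    case MemoryType::Shared:
      return isBlockDim(pt) || isDeviceDim(pt);
    case MemoryType::Global:
      return isDeviceDim(pt);
  }
  return false;
}

// Form suggested in the review: sharing as the inverse of partitioning.
constexpr bool isSharedViaPartition(MemoryType mt, ParallelType pt) {
  return (isThread(pt) || isDeviceDim(pt)) && !isPartitioned(mt, pt);
}

// Exhaustively compares the two forms over all enum combinations.
inline bool formsAgree() {
  const MemoryType mts[] = {
      MemoryType::Local, MemoryType::Shared, MemoryType::Global};
  const ParallelType pts[] = {
      ParallelType::DIDx, ParallelType::BIDx, ParallelType::BIDy,
      ParallelType::BIDz, ParallelType::TIDx, ParallelType::TIDy,
      ParallelType::TIDz, ParallelType::Serial, ParallelType::Vectorize};
  for (MemoryType mt : mts) {
    for (ParallelType pt : pts) {
      if (isSharedDirect(mt, pt) != isSharedViaPartition(mt, pt)) {
        return false;
      }
    }
  }
  return true;
}
```

The leading `isThread(pt) || isDeviceDim(pt)` guard is what keeps serial and vectorized domains out of the "shared" set: they are not partitioned, but they are not parallelized either.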
jacobhinkle left a comment
LGTM. Thanks for the explanations.
Stacked on top of #2344. Adds support for broadcast indexing with loop promotion. The main change is just the use of promoted domains in loop and allocation domains.
This is the first of a series of IdModel-based indexing PRs. The overall algorithm is exactly what I presented at the nvfuser meeting and is based on the shortest path traversal introduced in #2270. This PR only includes the minimum logic for basic indexing. However, pointwise and reduction ops with TID/BID parallelization, as well as reshape, should work.
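As a rough illustration of what indexing a pointwise op with TID/BID parallelization amounts to, the sketch below emulates on the CPU a 1D copy whose domain is split by a factor, with the outer part bound to a block index and the inner part to a thread index. This is not the IdModel-based indexer from this PR; `indexThroughSplit`, `ceilDiv` and `emulatePointwiseCopy` are hypothetical names, and the sketch only shows the standard rule that the index of a split domain is `outer * factor + inner`, guarded by a bounds predicate.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper: recover the index of the pre-split domain from the
// indices of the two domains produced by a split with the given factor.
constexpr int64_t indexThroughSplit(int64_t outer, int64_t inner, int64_t factor) {
  return outer * factor + inner;
}

constexpr int64_t ceilDiv(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

// CPU emulation of a pointwise copy where the outer loop plays the role of
// BIDx and the inner loop plays the role of TIDx.
inline std::vector<int64_t> emulatePointwiseCopy(
    const std::vector<int64_t>& in,
    int64_t factor) {
  const int64_t extent = static_cast<int64_t>(in.size());
  std::vector<int64_t> out(in.size(), -1);
  for (int64_t bid = 0; bid < ceilDiv(extent, factor); ++bid) {
    for (int64_t tid = 0; tid < factor; ++tid) {
      const int64_t idx = indexThroughSplit(bid, tid, factor);
      // Predicate: the split may overshoot when factor does not divide extent.
      if (idx < extent) {
        out[idx] = in[idx];
      }
    }
  }
  return out;
}
```

With an extent of 7 and a factor of 4, two "blocks" are launched and the last thread of the second block is masked out by the predicate.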
Notable missing support:
Note that the new indexing interface is not plugged into the lowering yet. Currently, it's only used by the new unit tests. In subsequent PRs, there'll be a switch to use the new indexer instead of the existing one.
Overall tracking PR: #2238