Support indexing of DIDx parallelized tensors #2364
Conversation
This only includes the minimum logic for basic indexing; notably, there is no support for broadcast in this PR.
jacobhinkle left a comment
LGTM. Ack about using isMemoryPartitionedAccess, but I think we're already assuming the sibling outputs share the same for-loops (see line 347), so maybe we should assert that all siblings have the same memory type and number of leaf IDs.
// should be used, but that means we would need to consider
// multiple outputs with different memory types, though it
// should be uncommon in practice.
shouldUseZeroIndex(loop_group) || isParallelTypeDeviceDim(ptype)) {
Could isParallelTypeDeviceDim(ptype) go inside shouldUseZeroIndex? If any ID in the group is DID-parallelized, then the loop must be trivial, right?
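A minimal sketch of what folding the check in might look like, using hypothetical stand-in types (`LoopGroup`, its `ids` member, and the `trivial` flag are illustrative placeholders, not nvFuser's actual classes); the idea is just that a DID-parallelized loop runs a single iteration per device, so a zero index suffices:

```cpp
#include <cassert>
#include <vector>

// Stand-in for nvFuser's ParallelType enum (illustrative subset only).
enum class ParallelType { Serial, TIDx, BIDx, DIDx };

bool isParallelTypeDeviceDim(ParallelType pt) {
  return pt == ParallelType::DIDx;
}

// Hypothetical loop-group representation for this sketch.
struct LoopGroup {
  std::vector<ParallelType> ids; // parallel types of the IDs in the group
  bool trivial = false;          // e.g., extent statically known to be 1
};

// Sketch of the suggested refactor: shouldUseZeroIndex also returns true
// when any ID in the group is device-dim parallelized, since each device
// sees only its own slice and the loop is trivial locally.
bool shouldUseZeroIndex(const LoopGroup& group) {
  if (group.trivial) {
    return true;
  }
  for (auto pt : group.ids) {
    if (isParallelTypeDeviceDim(pt)) {
      return true;
    }
  }
  return false;
}
```

With this shape, the call site would reduce to just `shouldUseZeroIndex(loop_group)`, dropping the separate `isParallelTypeDeviceDim(ptype)` disjunct.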
Not having to split the logical shape for DID is wonderful. For my education, what are the next steps so we can benefit from this work? I assume this PR fixes IdModel to allow a leaf-only DID split, but none/few schedulers use IdModel.

That could be a reasonable option, but supporting different memory types may be trivial. At least I'd give it a try.

I'll soon have a PR to integrate this new indexer into lowering. Something like this:

Makes sense for device lowering. My concern was about the schedulers, which don't yet use IdModel. Is IdModel required for the schedulers to handle a leaf-only DID split?

At this moment, no, I don't think so.
Stacked on top of #2353
Small changes to allow indexing of tensors with DIDx domains.
CC: @zasdfgbnm @cowanmeg @samnordmann @wujingyue