Handling allocation domain of the input TensorViews in the matmul scheduler #2309
Conversation
jacobhinkle
left a comment
It seems there are a few changes here:
- Don't propagate the allocation domain when using `cacheAfter` on smem buffers, since we need the loaded register buffers to have allocation domains matching their root domains.
- Change `scheduleLdMatrix` to check the consumer/producer innermost allocation ID to see whether a transpose is needed. Previously this was signalled by the `LoadStoreOpType` on that op.
- Use the allocation domain instead of the root domain for `orderTiledConcreteIdAsRoot`. This is called only on shared memory TVs, and we now have possibly non-trivial allocation domains on those tensors.
I think this generally seems fine. I do have a question that we can address in the future: how should we handle cases where there is a transposed operand which has a prologue that comes before the transpose? It seems that we still rely on using the smem->register load for transposing but in such a case that will come after the prologue.
csrc/scheduler/mma_utils.cpp
Outdated
```cpp
      {consumer->getMaybeAllocationDomain().back()});

  auto ids = ir_utils::filterByType<IterDomain>(vals);
  auto idsOnPath = std::vector<IterDomain*>(ids.begin(), ids.end());
```
I don't think this line is needed, is it? Just use `ids` instead of `idsOnPath`. Also a nit: `const` on all these variables.
I sort of based it on this:
Lines 738 to 748 in 5e0c89b
```cpp
// Filter so we only have iteration domains (ignore Ints used in split)
auto all_ids = ir_utils::filterByType<IterDomain>(all_vals);
return std::vector<IterDomain*>(all_ids.begin(), all_ids.end());
```
csrc/scheduler/mma_utils.cpp
Outdated
```cpp
// Get all the IDs from the innermost ID of the allocation domain of
// the consumer to the root domain of the consumer.
auto vals = DependencyCheck::getAllValsBetween(
    {consumer->getRootDomain().begin(), consumer->getRootDomain().end()},
```
Do you need to filter out broadcast and reduction domains?
Force-pushed from 5e0c89b to 32f15f7.
This PR needs a rebase so that the changes in #2315 are excluded from the diff of this PR.
Force-pushed from 32f15f7 to 7b0fd07.
!build

!build

!build
…eduler (#2309)
In this PR we extend the matmul scheduler to support inputs with allocation domains.

Given a fusion with inputs tv_a and tv_b, we add two LoadStoreOps to each input. The first op corresponds to a load to shared memory, where we propagate the allocation domain. The second op corresponds to reading to registers, where we don't propagate the allocation domain, since the scheduler takes charge of setting the allocation domain of the register buffers. Based on the difference between the (maybe) allocation domains of the producer and consumer of the second LoadStoreOp, we may do a transposed load when reading to registers.
See also #2315.