Support allocation domain in the matmul scheduler#2226
Conversation
csrc/scheduler/matmul.cpp
Outdated
```cpp
bool hasInnerTranspose(TensorView* out_tv) {
  auto use_root_or_alloc = out_tv->hasAllocation()
      ? out_tv->getAllocationDomain()
      : out_tv->getRootDomain();
  return use_root_or_alloc.back() != out_tv->getMaybeRFactorDomain().back();
}
```
See my note above. I think the existence of this function is a sign that we should think more about producer allocation to consumer allocation comparison instead of the root to rfactor. For example, if we are doing a transpose op, but we set the allocation domain of the consumer to match that of the producer, then we should not consider that to have a transpose.
Please take a look at the modified implementation.
QQ: is the allocation domain of
feb3b4d to 0273cc0
This does run all 4 cases. I'll update the comment/description of the PR.

!build
csrc/scheduler/matmul.cpp
Outdated
```cpp
// A LoadMatrixTranspose is needed if there's an rfactor domain which doesn't
// match the alloc/root domain. If there's no rfactor, then we need a
// transpose if the alloc domain and root don't match. When we say match,
// we refer to the innermost iter domain.
```
I think this assumes the allocation domain of the producer matches its rfactor domain (and out_tv's root).
tests/cpp/test_matmul_scheduler.cpp
Outdated
```cpp
if (tv1->hasAllocation()) {
  tv1t->setAllocationDomain({tv1t->axis(0), tv1t->axis(1)}, true);
}
```
I think the scheduler should actually be the one setting the allocation domain for this tensor. It would set the allocation domain on the register tensor that is the argument to the MmaOp since that is where the TN requirement comes from, then propagate allocation domain back to the register tensor that gets loaded from smem (i.e. the ldmatrix output).
```cpp
// For Turing and Ampere, the layout of the MmaOp is always TN
NVF_ERROR(
    mma_layout == MmaLayout::TN,
    "MMAs in Turing and Ampere are TN only, transpose is handled either "
    "via ldmatrix.trans for fp16 or explicitly for other types.");
```
I think we could use a utility `propagateMmaProducerLayouts(MmaOp* mma_op, const std::vector<TensorView*>& to)` which sets the operand allocation domains for the MmaOp based on its macro, then propagates that upstream to the given tensors (inclusive). We'd call it like `propagateMmaProducerLayouts(mma_op, {acr, bcr});` I think.
jjsjann123 left a comment:
LGTM. I'll let @jacobhinkle stamp.
```cpp
  splitk_sum->axis(-1)->parallelize(ParallelType::Vectorize);
}

bool needsTranposedLoad(TensorView* producer, TensorView* consumer) {
```
Suggested change (typo fix):
```diff
-bool needsTranposedLoad(TensorView* producer, TensorView* consumer) {
+bool needsTransposedLoad(TensorView* producer, TensorView* consumer) {
```
```cpp
     {std::make_pair(acw_smem, &acr), std::make_pair(bcw_smem, &bcr)}) {
  auto producer = tv_smem;
  auto consumer = tv_smem->uses().at(0)->output(0)->as<TensorView>();
  auto toTranspose = needsTranposedLoad(producer, consumer);
```
Suggested change (typo fix):
```diff
-auto toTranspose = needsTranposedLoad(producer, consumer);
+auto toTranspose = needsTransposedLoad(producer, consumer);
```
```cpp
const auto map =
    PairwiseRootDomainMap(producer, consumer).mapProducerToConsumer();
auto maybeProducerAlloc = producer->getMaybeAllocationDomain();
auto maybeConsumerRFactor = consumer->getMaybeRFactorDomain();
```
nitpick: let's assert on `std::is_permutation`. Otherwise we can falsely claim needsTransposedLoad.
```cpp
*tv_r = ldst->out()->as<TensorView>();
ldst->setOpType(
    toTranspose ? LoadStoreOpType::LdMatrixTranspose
                : LoadStoreOpType::LdMatrix);
```
Out of curiosity, what's planned to change for the follow-up PR when you say:

> In a follow-up PR I can handle LdMatrix/LdMatrixTrans.
This is referring to us inferring the ptx instruction based on allocation domain. I think we need to look at consumer alloc to consumer root and compose that with producer rfactor to producer alloc. The resulting composition can tell us whether a transpose is needed as it combines both the allocation information and the semantic information in consumer's root to rfactor domain.
Thanks for filling that in~ 🙇
```diff
@@ -57,11 +57,11 @@ void makeTile(TensorView* tv, std::vector<int64_t> tile_sizes);

 //! Order the inner tile dimensions as the original order in
 //! root domain. Also putting broadcast domains on the left.
```
in the original order in the allocation domain instead?
```cpp
auto maybe_root =
    getMaybeRootIfInnermostTiled(leaf_id, maybe_rfactor_id_set);
auto maybe_alloc_domain =
    getMaybeAllocationIfInnermostTiled(leaf_id, id_set);
```
nitpick: I think the logic here assumes rfactor and allocation are just a permutation of each other? Should we add an assert somewhere?
So I guess we are doing that somewhere here (or maybe it's already done in an even earlier spot). If so, please ignore my comment in the other place where I asked to assert the permutation.
Thanks for the reviews, but given that our design has evolved, I'll create a new PR.
This replaces the `CombineMulSum` class with `MatmulPattern` in the Matmul scheduler. Additionally, we use these matmul patterns to determine the problem layout, IterDomain roles, and TensorView roles. The allocation domain is used to determine the problem layout. The matmul scheduler is updated to reject segments whose input allocation domains are non-trivial (until that is supported, e.g. by #2226).

Note that this does not add handling of `MatmulOp` and `LinearOp` in the matmul scheduler. That will be done next in #2236 or similar.

---------

Co-authored-by: Priya Mishra <52657555+Priya2698@users.noreply.github.com>
Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>
In this discussion when we say the matmul scheduler, we’ll mean the scheduleMatmul function.
We will also assume that the implementation of the above is tied to the recently developed Matmul and Linear ops. In this note, we'll focus on the matmul op.
As shown in the figure above, we will convert the matmul op, which takes inputs A and B with root domains [M, K] and [K, N] respectively, into a collection of transpose, broadcast, and Mma ops. Please note that the inputs A and B can both have allocation domains (transposed memory layouts); we do not place restrictions on that.
The task at hand focuses on how this collection of ops is handled in the scheduleMatmul function when the inputs have allocation domains. The table below outlines the configurations of inputs A and B that we want to support.
Discussion with @jacobhinkle and @zasdfgbnm led us to identify the following issues to get the above cases running:
This was solved by Xiang a while back in commit 626a405.
This refers to the bit of code here.
We have to modify how we choose LdMatrix vs. LdMatrixTranspose based not only on the comparison between the rfactor domain and the root domain, but also taking into account the allocation domains of the shared memory TVs for A and B. What I took this to mean is: for the shared memory TV of B, bcw_smem, since it's transposed, we will compare its rfactor domain against its allocation domain (or the root domain if no allocation domain is set) to decide whether to transpose. And for input A, acw_smem, we'll compare the allocation domain and the root domain to see if we need to transpose. Please correct me if I have this wrong.
@jacobhinkle also pointed out that optionally, we may get rid of the enum LdMatrix/LdMatrixTrans and handle the transpose when lowering.
In this PR, I prototyped the steps mentioned in the first three bullets above.
The code here is able to handle all the cases mentioned in the table above.
In a follow-up PR I can handle LdMatrix/LdMatrixTrans.