Support allocation domain in the matmul scheduler#2226

Closed
protonu wants to merge 11 commits intomainfrom
pbasu_experiment_handle_mkn_k_unitstride

Conversation


@protonu protonu commented May 9, 2024

[Figure: the matmul op decomposed into transpose, broadcast, and Mma ops]

In this discussion when we say the matmul scheduler, we’ll mean the scheduleMatmul function.

We will also assume that the implementation of the above is tied to the recently developed Matmul and Linear ops. In this note, we'll focus on the matmul op.

As shown in the figure above, we will convert the matmul op, which takes in inputs A and B with root domains [M, K] and [K, N] respectively, to a collection of transpose, broadcast, and Mma ops. Please note that the inputs A and B can both have allocation domains (transposed memory layouts - we do not place restrictions on that).

The task at hand focuses on how this collection of ops is handled in the scheduleMatmul function when the inputs have allocation domains. The table below outlines the configuration of inputs A and B we want to support.

| Root Domain A | Root Domain B | Allocation Domain A | Allocation Domain B |
|---------------|---------------|---------------------|---------------------|
| [M, K]        | [K, N]        | []                  | []                  |
| [M, K]        | [K, N]        | []                  | [N, K]              |
| [M, K]        | [K, N]        | [K, M]              | []                  |
| [M, K]        | [K, N]        | [K, M]              | [N, K]              |
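The unit-stride question these configurations raise can be sketched with a tiny standalone model (hypothetical helpers, not the nvFuser API): an operand's effective memory order is its allocation domain when one is set, otherwise its root domain, and the K dimension has unit stride exactly when it is innermost in that order.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Minimal model of the table above (hypothetical, not the nvFuser API):
// a domain is an ordered list of dimension names, and an empty allocation
// domain means the memory order follows the root domain.
using Domain = std::vector<std::string>;

// The effective memory order is the allocation domain when one is set,
// otherwise the root domain.
Domain effectiveOrder(const Domain& root, const Domain& alloc) {
  return alloc.empty() ? root : alloc;
}

// An operand has unit-stride K when K is the innermost dimension of its
// effective memory order.
bool kIsInnermost(const Domain& root, const Domain& alloc) {
  return effectiveOrder(root, alloc).back() == "K";
}
```

For example, A with root [M, K] and no allocation domain has unit-stride K, while A with allocation domain [K, M] does not; B with root [K, N] gains unit-stride K only when its allocation domain is [N, K].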

Discussion with @jacobhinkle and @zasdfgbnm led us to identify the following issues to get the above cases running:

  • Propagate allocation domain from gmem operands to their smem consumers.
    This was solved by Xiang a while back (commit 626a405).
  • Update mma_utils::orderTiledConcreteIdAsRoot to use the allocation domain instead of the root domain.
  • Detect transpose properly in scheduleLdMatrix.
    This refers to the bit of code here. We have to modify how we choose LdMatrix vs. LdMatrixTranspose so that it is based not only on the comparison between the rfactor domain and the root domain, but also takes into account the allocation domains of the shared memory TVs for A and B. What I took this to mean is: for the shared memory TV of B, bcw_smem, since it's transposed, we will compare its rfactor domain against its allocation domain (or the root domain, if no allocation domain is set) to decide whether to transpose. And for input A, acw_smem, we'll compare the allocation domain against the root domain to see if we need to transpose. Please correct me if I have this wrong.
  • Update inline_ptx.cpp to detect when to append .trans instead of using a different LoadStoreOpType (optional).
    @jacobhinkle also pointed out that, optionally, we may get rid of the LdMatrix/LdMatrixTrans enum and handle the transpose when lowering.
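The comparison rule in the third bullet can be sketched as a small standalone model (hypothetical names, not the actual scheduleLdMatrix code): take the allocation domain if present, else the root domain, and compare its innermost dimension against the innermost rfactor dimension.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical model of the LdMatrix vs. LdMatrixTranspose decision
// described above (not the actual nvFuser code). Domains are ordered
// lists of dimension names.
using Domain = std::vector<std::string>;

// Compare the innermost dimension of the memory layout (allocation domain
// if set, root domain otherwise) against the innermost rfactor dimension.
// A mismatch means the load has to transpose.
bool transposeNeeded(
    const Domain& rfactor, const Domain& root, const Domain& alloc) {
  const Domain& mem = alloc.empty() ? root : alloc;
  return mem.back() != rfactor.back();
}
```

For instance, with root [M, K], rfactor [M, K], and allocation domain [K, M], the innermost memory dimension is M while the innermost rfactor dimension is K, so a transposed load would be needed; with no allocation domain set, the two agree and a plain load suffices.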

In this PR, I prototyped the steps mentioned in the first three bullets above.
The code here is able to handle all the cases mentioned in the table above.
In a follow-up PR I can handle LdMatrix/LdMatrixTrans.

@protonu protonu changed the title [do not land]getting (m,k,n) case with unit stride K running [do not land] This runs matmuls where rfactors for A and B are [M, K] and [K, N] respectively. May 9, 2024
Comment on lines +737 to +747
```cpp
bool hasInnerTranspose(TensorView* out_tv) {
  auto use_root_or_alloc = out_tv->hasAllocation()
      ? out_tv->getAllocationDomain()
      : out_tv->getRootDomain();
  return use_root_or_alloc.back() != out_tv->getMaybeRFactorDomain().back();
}
```

Collaborator

See my note above. I think the existence of this function is a sign that we should think more about producer allocation to consumer allocation comparison instead of the root to rfactor. For example, if we are doing a transpose op, but we set the allocation domain of the consumer to match that of the producer, then we should not consider that to have a transpose.
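The producer-to-consumer comparison suggested here can be sketched as a minimal standalone model (hypothetical, not the nvFuser API): if the consumer's allocation order matches the producer's, no transposed load is needed, even across a transpose op.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical model of the producer-alloc to consumer-alloc comparison
// suggested above (not the nvFuser API): a transposed load is needed only
// when the innermost allocated dimensions of producer and consumer differ.
using Domain = std::vector<std::string>;

bool needsTransposedLoadModel(
    const Domain& producer_alloc, const Domain& consumer_alloc) {
  return producer_alloc.back() != consumer_alloc.back();
}
```

In the scenario described above, a transpose op whose consumer allocation domain is set back to the producer's order (say both [M, K]) compares equal under this model, so no transposed load is inferred.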

Collaborator Author

Please take a look at the modified implementation.

@jjsjann123
Collaborator

This runs matmuls where Rfactors for A and B are [M, K] and [K, N] respectively.

There are two versions of this:

  1. B has an allocation domain of [N, K]
  2. B has an allocation domain of [K, N]

QQ: is the allocation domain of A not impacting kernel generation? I thought we are looking at their combination which would be 4 cases.

@protonu protonu force-pushed the pbasu_experiment_handle_mkn_k_unitstride branch from feb3b4d to 0273cc0 on May 15, 2024 15:02
@protonu
Collaborator Author

protonu commented May 15, 2024

QQ: is the allocation domain of A not impacting kernel generation? I thought we are looking at their combination which would be 4 cases.

This does run all the 4 cases. I'll update the comment/description of the PR.

@protonu protonu changed the title [do not land] This runs matmuls where rfactors for A and B are [M, K] and [K, N] respectively. Support allocation domain in the matmul scheduler May 15, 2024
@protonu
Collaborator Author

protonu commented May 15, 2024

!build

Comment on lines +737 to +740
```cpp
// A LdMatrixTranspose is needed if there's an rfactor which doesn't
// match the alloc/root domain. If there's no rfactor, then we
// need a transpose if the alloc domain and root don't match. When we say
// match, we refer to the innermost iter domain.
```
Collaborator

I think this assumes the allocation domain of the producer matches its rfactor domain (and out_tvs root).

Comment on lines +2867 to +2869
```cpp
if (tv1->hasAllocation()) {
  tv1t->setAllocationDomain({tv1t->axis(0), tv1t->axis(1)}, true);
}
```
Collaborator

I think the scheduler should actually be the one setting the allocation domain for this tensor. It would set the allocation domain on the register tensor that is the argument to the MmaOp since that is where the TN requirement comes from, then propagate allocation domain back to the register tensor that gets loaded from smem (i.e. the ldmatrix output).

Comment on lines -914 to -919
```cpp
// For Turing and Ampere, the layout of the MmaOp is always TN
NVF_ERROR(
    mma_layout == MmaLayout::TN,
    "MMAs in Turing and Ampere are TN only, transpose is handled either "
    "via ldmatrix.trans for fp16 or explicitly for other types.");
```

Collaborator

I think we could use a utility propagateMmaProducerLayouts(MmaOp* mma_op, const std::vector<TensorView*>& to) which sets the operand allocation domains for the MmaOp based on its macro, then propagates that upstream to the given tensors (inclusive). We'd call it like propagateMmaProducerLayouts(mma_op, {acr, bcr}); I think.

@jjsjann123 jjsjann123 marked this pull request as ready for review May 15, 2024 20:42
@jjsjann123 jjsjann123 left a comment

LGTM. I'll let @jacobhinkle stamp.

```cpp
  splitk_sum->axis(-1)->parallelize(ParallelType::Vectorize);
}

bool needsTranposedLoad(TensorView* producer, TensorView* consumer) {
```
Collaborator

Suggested change
```diff
-bool needsTranposedLoad(TensorView* producer, TensorView* consumer) {
+bool needsTransposedLoad(TensorView* producer, TensorView* consumer) {
```

```cpp
    {std::make_pair(acw_smem, &acr), std::make_pair(bcw_smem, &bcr)}) {
  auto producer = tv_smem;
  auto consumer = tv_smem->uses().at(0)->output(0)->as<TensorView>();
  auto toTranspose = needsTranposedLoad(producer, consumer);
```
Collaborator

Suggested change
```diff
-auto toTranspose = needsTranposedLoad(producer, consumer);
+auto toTranspose = needsTransposedLoad(producer, consumer);
```

```cpp
const auto map =
    PairwiseRootDomainMap(producer, consumer).mapProducerToConsumer();
auto maybeProducerAlloc = producer->getMaybeAllocationDomain();
auto maybeConsumerRFactor = consumer->getMaybeRFactorDomain();
```
Collaborator

nitpick: let's assert on std::is_permutation. Otherwise we can falsely claim needsTransposedLoad.
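The suggested guard could look like this sketch (hypothetical helper, not the actual PR code), using std::is_permutation so that an innermost-dimension mismatch is guaranteed to come from a reordering rather than a split or merge:

```cpp
#include <algorithm>
#include <cassert>
#include <string>
#include <vector>

// Sketch of the suggested assertion (hypothetical, not the actual PR
// code): verify the allocation domain is a pure permutation of the
// rfactor domain before using an innermost-dimension mismatch to infer
// a transposed load.
bool isPermutationOf(
    const std::vector<std::string>& alloc,
    const std::vector<std::string>& rfactor) {
  return std::is_permutation(
      alloc.begin(), alloc.end(), rfactor.begin(), rfactor.end());
}
```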

```cpp
*tv_r = ldst->out()->as<TensorView>();
ldst->setOpType(
    toTranspose ? LoadStoreOpType::LdMatrixTranspose
                : LoadStoreOpType::LdMatrix);
```
Collaborator

Out of curiosity, what's planned to change for the follow-up PR when you say:

In a follow-up PR I can handle LdMatrix/LdMatrixTrans.

Collaborator

This is referring to us inferring the ptx instruction based on allocation domain. I think we need to look at consumer alloc to consumer root and compose that with producer rfactor to producer alloc. The resulting composition can tell us whether a transpose is needed as it combines both the allocation information and the semantic information in consumer's root to rfactor domain.
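The composition described above can be sketched with a standalone model (hypothetical, not the nvFuser API): treat each mapping as a permutation of axis indices, where p[i] names the source axis that lands at destination position i, compose the two, and check whether the innermost axis moves.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sketch (not the nvFuser API) of composing consumer
// alloc-to-root with producer rfactor-to-alloc, as described above.
// A permutation p maps destination axis i to source axis p[i].
using Perm = std::vector<int>;

// Apply f first, then g: result[i] = f[g[i]].
Perm compose(const Perm& g, const Perm& f) {
  Perm out(g.size());
  for (std::size_t i = 0; i < g.size(); ++i) {
    out[i] = f[g[i]];
  }
  return out;
}

// A transposed load is needed when the composed mapping moves the
// innermost axis away from the innermost position.
bool innermostMoves(const Perm& p) {
  return p.back() != static_cast<int>(p.size()) - 1;
}
```

Under this model, two transposes compose to the identity, so no .trans is inferred when the allocation undoes the semantic transpose; a single swap leaves the innermost axis moved and would select the transposed load.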

Collaborator

Thanks for filling that in~~ 🙇

@@ -57,11 +57,11 @@ void makeTile(TensorView* tv, std::vector<int64_t> tile_sizes);

//! Order the inner tile dimensions as the original order in
//! root domain. Also putting broadcast domains on the left.
Collaborator

in the original order in allocation domain instead?

```cpp
auto maybe_root =
    getMaybeRootIfInnermostTiled(leaf_id, maybe_rfactor_id_set);
auto maybe_alloc_domain =
    getMaybeAllocationIfInnermostTiled(leaf_id, id_set);
```
Collaborator

nitpick: I think the logic here assumes rfactor and allocation are just permutation? Should we add an assert somewhere?

Collaborator

So I guess we are doing that somewhere here (or maybe it's already done in an even earlier spot), so please ignore my comment in the other place where I asked to assert the permutation.

@protonu
Collaborator Author

protonu commented May 21, 2024

Thanks for the reviews, but given that our design has evolved, I'll create a new PR.

@protonu protonu closed this May 21, 2024
jacobhinkle added a commit that referenced this pull request May 23, 2024
This replaces the `CombineMulSum` class with `MatmulPattern` in the
Matmul scheduler. Additionally, we use these matmul patterns to
determine the problem layout, IterDomain roles, and TensorView roles.
The allocation domain is used to determine the problem layout. The
matmul scheduler is updated to reject segments whose input allocation
domains are non-trivial (until that is supported, e.g., by #2226).

Note that this does not add handling of `MatmulOp` and `LinearOp` in the
matmul scheduler. That will be done next in #2236 or similar.

---------

Co-authored-by: Priya Mishra <52657555+Priya2698@users.noreply.github.com>
Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>
@protonu protonu deleted the pbasu_experiment_handle_mkn_k_unitstride branch June 4, 2024 16:53