
Canonicalize matmul dims in scheduleMatmul using dimension roles#2376

Merged
jacobhinkle merged 23 commits into main from canonicalize_dims on Jun 25, 2024

Conversation

@jacobhinkle
Collaborator

@jacobhinkle jacobhinkle commented Jun 10, 2024

This updates scheduleMatmul to use the dimension roles introduced in #2303 to reorder and merge dims within each role. That means we can schedule fusions with more than one tensor in each role.

Two tests are included so far:

  • MultipleMDims, which performs [M1, M2, K] @ [N, K]. This corresponds to torch.nn.functional.linear with a 3D input, i.e., a Linear layer with two "batch" dimensions.
  • MultipleMDimsBatch, which performs [M1, B, M2, K] @ [B, N, K]. This shows that M dimensions need not be contiguous. Note that vectorization is unaffected in this example since K is innermost.

I plan to add more tests to explore more scenarios.

Note that we cannot yet handle cases where K is not the innermost dimension, and I have not yet tested cases with M as the innermost output dimension. These will likely require us to modify the fusion definition to place those dimensions in the right spot in the MmaOp.
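The reorder-and-merge step described above can be pictured with a small stand-alone sketch (the enum and `canonicalizeByRole` are illustrative names, not nvFuser's actual API): compute a permutation that groups same-role dims adjacently in a fixed [Batch, M, N, K] order, after which the dims within each role can be merged.

```cpp
#include <cassert>
#include <vector>

// Illustrative role enum; nvFuser's real dimension roles come from #2303.
enum class MatmulDomain { Batch, M, N, K };

// Hypothetical sketch: given each dimension's role, return a permutation
// that places same-role dims adjacent in canonical [Batch, M, N, K] order,
// preserving the original relative order within each role.
std::vector<int> canonicalizeByRole(const std::vector<MatmulDomain>& roles) {
  std::vector<int> perm;
  for (MatmulDomain r :
       {MatmulDomain::Batch, MatmulDomain::M, MatmulDomain::N, MatmulDomain::K}) {
    for (int i = 0; i < (int)roles.size(); ++i) {
      if (roles[i] == r) {
        perm.push_back(i);
      }
    }
  }
  return perm;
}
```

For example, the MultipleMDimsBatch layout [M1, B, M2, K] has roles {M, Batch, M, K}, and the sketch yields the permutation {1, 0, 2, 3}: the batch dim moves outermost and M1, M2 become adjacent so they can be merged.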

@jacobhinkle
Collaborator Author

!build

Comment on lines -81 to -90
//! Matches the following matmul patterns.
//! Matmul: A x B, alpha * A x B
//! Matmul + Bias (addmm): A x B + C, alpha * A x B + C, A x B + beta * C,
//! alpha * A x B + beta * C
//! Linear: A x B / A x B + C
//! Assumptions:
//! 1. For simplicity, we assume the MmaOp to be in the first operand.
//! 2. For linear ([M, K], [N, K]), alpha, beta parameters are nullptr.
bool matchMatmulPatterns(const UnaryOp* cast_op, MatmulInputs* matmul_inp);

@jacobhinkle (Collaborator Author)

This is declared but never defined or called, so it should have been removed in a previous PR.

@jacobhinkle jacobhinkle marked this pull request as ready for review June 13, 2024 20:03
Comment on lines +299 to +302
// Also check that dims within each role are consecutive with one another
// for this pattern.
// TODO: Lift this requirement by modifying the definition or setting
// allocation domains to support this setting in MmaOp
@jacobhinkle (Collaborator Author)

I believe this condition is sufficient to avoid the problematic cases like the included test MultipleNonConsecutiveNDims.

Comment on lines +1890 to +1891
// Insert the device dims first, then skip them when inserting dims from each
// other role
@jacobhinkle (Collaborator Author)

This ensures device dim groups are placed outermost, even if that means non-consecutive dims within each role.

}

// This is a tougher test where we insert a batch dim between the M dims
TEST_F(GPUTTensorCoreTest, MultipleMDimsBatch) {
@jacobhinkle (Collaborator Author)

A batch dim can be inserted between M dims, but the K dim cannot. This is because the batch dimension gets parallelized, so M1 and M2 will be contiguous with one another in the smem tensor.

@jacobhinkle jacobhinkle requested a review from zasdfgbnm June 13, 2024 20:11
Comment on lines +3239 to +3240
tv0 = broadcast(tv0, {false, false, true, false});
tv1 = broadcast(tv1, {true, true, false, false});
@zasdfgbnm (Collaborator)

Why are the Ms discontiguous?

@jacobhinkle (Collaborator Author)

Oops, error in the comment: this is actually the consecutive-Ms case. I'll rename it accordingly.

Also add a disabled non-consecutive version
NVF_ERROR(it != id_roles.end());
role_order.pushBack(it->second);
}
NVF_ERROR(
@zasdfgbnm (Collaborator), Jun 24, 2024

So scheduleMatmul can now partially support having multiple M/N dims, but we are still rejecting it in the scheduler for now?

@zasdfgbnm (Collaborator)

Wait, should this be the following?

if (role_order.size() != 3 && role_order.size() != 4) {
  return "Expected either….";
}

@jacobhinkle (Collaborator Author)

> Wait, should this be the following?

Ah! Good catch!

@jacobhinkle (Collaborator Author)

> So scheduleMatmul can now partially support having multiple M/N dims, but we are still rejecting it in the scheduler for now?

This check will not complain if there are multiple M or N dims, since we build role_order as the ordering of the roles, not of all the dimensions that constitute those roles. For example, we might have A[M1, K1, M2] and B[K1, N1] with ordering [M, N, K].

@zasdfgbnm (Collaborator)

> So scheduleMatmul can now partially support having multiple M/N dims, but we are still rejecting it in the scheduler for now?
>
> This check will not complain if there are multiple M or N dims, since we build role_order as the ordering of the roles not all the dimensions that constitute all the roles. For example we might have A[M1, K1, M2] and B[K1, N1] and ordering [M, N, K].

But wouldn't the role_order.size() check reject the fusion?

@jacobhinkle (Collaborator Author)

No, role_order would be [M, N, K] in that case. It doesn't hold dimensions, but roles.
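The distinction can be made concrete with a stand-alone sketch (illustrative names, not nvFuser's actual implementation): role_order records each role once, in order of first appearance, so multiple dims sharing a role do not inflate its size.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Illustrative role enum; nvFuser's real dimension roles come from #2303.
enum class MatmulDomain { Batch, M, N, K };

// Hypothetical sketch: deduplicate per-dimension roles into an ordered
// list of distinct roles. A problem with dims [M1, K1, M2, N1, K1] still
// yields a role_order of size 3, so a size check of 3 or 4 still passes.
std::vector<MatmulDomain> buildRoleOrder(
    const std::vector<MatmulDomain>& dim_roles) {
  std::vector<MatmulDomain> role_order;
  for (MatmulDomain r : dim_roles) {
    if (std::find(role_order.begin(), role_order.end(), r) ==
        role_order.end()) {
      role_order.push_back(r);
    }
  }
  return role_order;
}
```

With per-dim roles {M, K, M, N, K}, this produces {M, K, N}: three entries regardless of how many M or K dims the fusion has.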

.reshape({M, N1, N2});
NVF_CHECK(cg_outputs[0].allclose(tref, 0.0001, 0.0001));
}

@zasdfgbnm (Collaborator)

Do we also need a MultipleConsecutiveKDims? Or, maybe combine all these tests together to do a (B, B, M, M, N, N, K, K) matmul?

@jacobhinkle (Collaborator Author)

Yeah they are pretty orthogonal so I think I could just combine them into one.

@jacobhinkle (Collaborator Author)

MmaOp currently has a restriction that there can only be a single K dimension. I added a check for this in isMatmulFusionDefinitionSupported since such a pattern could be created as a mul-sum; it cannot be created with matmul or linear.

Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>
fusion.addInput(tv0);
fusion.addInput(tv1);

// M1, N, K, M2
@zasdfgbnm (Collaborator)

Should this be M1, N, M2, K?

@jacobhinkle (Collaborator Author)

Actually, this is M1, N, K, M2; it's just that I screwed up the definition. I think I will combine this with the other disabled test and have non-consecutive M, N, and K dims. Similarly, I'll combine the enabled consecutive tests.


fusion.addInput(tv0);
fusion.addInput(tv1);

@zasdfgbnm (Collaborator)

// M1, B, M2, N, K

@jacobhinkle jacobhinkle merged commit 7bc7a08 into main Jun 25, 2024
@jacobhinkle jacobhinkle deleted the canonicalize_dims branch June 25, 2024 15:34
protonu pushed a commit that referenced this pull request Jun 25, 2024