
Translate MatmulOp and LinearOp #2236

Merged

jacobhinkle merged 61 commits into main from translate_matmul_op on May 30, 2024

Conversation

@jacobhinkle
Collaborator

@jacobhinkle jacobhinkle commented May 13, 2024

The purpose of this PR is to enable the NVFuser matmul scheduler to operate on the new MatmulOp and LinearOp nodes. This means the matmul scheduler can optionally schedule such segments, including ones that are not supported by ATen's matmul and linear functions.

Specifically, this PR:

  1. Adds MatmulOp and LinearOp to the set of MatmulPatterns detected. Previously only MmaOp and combined mul-sum patterns were detected.
  2. Enables translation of MatmulOp and LinearOp to fixed broadcast+MmaOp patterns.
  3. Introduces EnableOption::FuseMatmul and associated NVFUSER_ENABLE=fuse_matmul option to enable the automatic scheduler to accept matmul patterns consisting of a MatmulOp or LinearOp. By default, the matmul scheduler will not accept segments containing MatmulOp or LinearOp patterns, meaning all those nodes will be computed using the ExprEval scheduler (ATen).
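To make the translation in point 2 concrete, here is a hedged plain-Python sketch (not NVFuser API) of the broadcast+multiply+sum computation that the broadcast+MmaOp pattern expresses: for linear, out[m][n] = sum_k A[m][k] * B[n][k], plus an optional bias. The function name is my own, chosen for illustration.

```python
# Hedged sketch (not NVFuser's implementation): the broadcast+multiply+reduce
# pattern that a LinearOp translation targets, written with plain Python lists.
# A has shape [M, K], B has shape [N, K], bias (optional) has shape [N].

def linear_via_mul_sum(A, B, bias=None):
    M, K = len(A), len(A[0])
    N = len(B)
    out = []
    for m in range(M):
        row = []
        for n in range(N):
            # Conceptually: broadcast A over the N dimension and B over the M
            # dimension, multiply elementwise, then reduce over K (the MmaOp).
            acc = sum(A[m][k] * B[n][k] for k in range(K))
            if bias is not None:
                acc += bias[n]
            row.append(acc)
        out.append(row)
    return out
```

A LinearOp with weight B of shape [N, K] contracts over the last axis of both operands, which is why the translation only needs broadcasts plus a single reduction over K.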

Not all cases that are supported by these IR nodes can be translated to an MmaOp. In particular:

  1. gemv cases, where one operand is 1D.
  2. Cases with multiple batch dimensions.
  3. Cases where "batch" dimensions must be added by unsqueezing one dimension. Those new dimensions are indistinguishable from, e.g., multiple M dimensions, which we do not yet support.

These correspond to test cases in the two new tests. See the TODO comments for descriptions of cases we plan to support but cannot yet translate.

There is also a commented-out LinearOp test case with M=N=1, which should not be translated at all.

@jacobhinkle
Collaborator Author

jacobhinkle commented May 13, 2024

This test was failing (this was fixed in #2272):

// We are broadcasting to a tensor that will have too many dims
// to be valid for an MmaOp.
std::vector<bool> bcast_dims(tv0->nDims() + 2, false);
bcast_dims.at(bcast_dims.size() - 2) = true;
bcast_dims.at(bcast_dims.size() - 3) = true;
auto tv0b = broadcast(tv0t, bcast_dims);
bcast_dims.at(bcast_dims.size() - 2) = false;
bcast_dims.at(bcast_dims.size() - 3) = true;
bcast_dims.at(bcast_dims.size() - 4) = true;
auto tv1b = broadcast(tv1t, bcast_dims);
auto tv2 = mul(tv0b, tv1b);
auto tv3 = sum(tv2, {-1});
fusion.addOutput(tv3);

I haven't implemented this condition because I don't think we really want it. Rather, we would prefer to accept patterns with multiple M, N, K, or Batch dims and simply canonicalize them via reorder/merge at the beginning of scheduling. Still, if we do support that then we should update the test to actually check that they are properly scheduled; until then we should probably keep the check strict.
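To illustrate what "canonicalize them via reorder/merge" would mean in shape terms, here is a hedged plain-Python sketch (my own illustration of the dimension bookkeeping only, not NVFuser's scheduler code): an operand with two M dimensions, shape (M1, M2, K), can be flattened to a single M = M1*M2 dimension before scheduling.

```python
# Hedged illustration (assumed shapes, not NVFuser's scheduler API): merge two
# leading "M" dimensions of a nested-list tensor of shape (m1, m2, k) into one,
# producing shape (m1*m2, k), which the matmul scheduler could then handle as a
# single canonical M dimension.

def merge_leading_dims(t, m1, m2):
    assert len(t) == m1 and all(len(row) == m2 for row in t)
    # Row-major merge: index (i, j) maps to i * m2 + j, preserving order.
    return [t[i][j] for i in range(m1) for j in range(m2)]
```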

jacobhinkle added a commit that referenced this pull request May 23, 2024
This replaces the `CombineMulSum` class with `MatmulPattern` in the
Matmul scheduler. Additionally, we use these matmul patterns to
determine the problem layout, IterDomain roles, and TensorView roles.
The allocation domain is used to determine the problem layout. The
matmul scheduler is updated to reject segments whose input allocation
domains are non-trivial (until that is supported, e.g. by #2226).

Note that this does not add handling of `MatmulOp` and `LinearOp` in the
matmul scheduler. That will be done next in #2236 or similar.

---------

Co-authored-by: Priya Mishra <52657555+Priya2698@users.noreply.github.com>
Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>
@jacobhinkle jacobhinkle requested a review from Priya2698 May 29, 2024 23:32
@jacobhinkle
Collaborator Author

!build

Comment on lines +1636 to +1649
dim_roles[exact_graph.toGroup(A->axis(-1))] = MatmulDomain::K;
NVF_ERROR(A->nDims() > 0 && B->nDims() > 0);
size_t m_and_k_dims = 0;
if (A->nDims() == 1 && B->nDims() == 1) {
NVF_ERROR(
false, "MatmulOp node should not be created when both inputs are 1D");
} else if (A->nDims() == 1) {
// Missing M dimension
dim_roles[exact_graph.toGroup(B->axis(-1))] = MatmulDomain::N;
m_and_k_dims = 1;
} else if (B->nDims() == 1) {
// Missing N dimension
dim_roles[exact_graph.toGroup(A->axis(-2))] = MatmulDomain::M;
m_and_k_dims = 1;
Collaborator

I see. My suggestion was that the role information is embedded in the position of the iterdomain in the mapping output. For eg: out_size-3, out_size-2, out_size-1 are M, N, K respectively.

Could you clarify which approach is erroneous -- mappingMatmulOpIterDomain or this PR? Why should iS5 be M instead of batch?
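As a plain-Python restatement of the role-assignment logic in the C++ snippet above (hedged: the dict-of-positions representation is my own, not NVFuser's ExactGraph mapping), the convention for a MatmulOp is that A's last axis is K, B's last axis is N, and A's second-to-last axis is M, with a 1D operand dropping M or N respectively:

```python
# Hedged sketch of the MatmulOp dimension-role assignment discussed above.
# Positions are axis indices into A or B; a 1D operand omits its M or N role.

def matmul_dim_roles(a_ndims, b_ndims):
    assert a_ndims > 0 and b_ndims > 0
    assert not (a_ndims == 1 and b_ndims == 1), \
        "MatmulOp node should not be created when both inputs are 1D"
    # A's last axis is always the contraction (K) dimension.
    roles = {("A", a_ndims - 1): "K"}
    if a_ndims == 1:
        # gemv: missing M dimension
        roles[("B", b_ndims - 1)] = "N"
    elif b_ndims == 1:
        # gemv: missing N dimension
        roles[("A", a_ndims - 2)] = "M"
    else:
        roles[("A", a_ndims - 2)] = "M"
        roles[("B", b_ndims - 1)] = "N"
    return roles
```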

TEST_F(MatmulSchedulerTest, Llama2FFN) {
NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0);

for (bool enable_fusion : {false, true}) {
Collaborator

I would recommend parametrizing this; the same parametrization pattern can be used with the SegmentMatmulOpPrologue and SegmentLinearOpPrologue tests as well, to test all three combinations of schedulers listed in this comment --


// TODO: Once we can control the ExprEval and Matmul schedulers via options, run
// this test with all three combinations (with and without each scheduler, but
// at least one enabled).

Collaborator Author

Oh that's a good idea

Collaborator Author

I added a parametrization for the Llama2FFN test, but the prologue tests are currently failing when fusion is enabled, so I will leave that for a separate PR (it might be fixed by #2309).

@jacobhinkle
Collaborator Author

!build

@jacobhinkle
Collaborator Author

!build

Collaborator

@Priya2698 left a comment

LGTM. Thanks for addressing the comments.
