MmaOp - consistency checks for basic matmul #131

Merged
drzejan2 merged 1 commit into main from ab/MmaOp_def_consistency_checks
Apr 13, 2023

Conversation

@drzejan2
Contributor

@drzejan2 drzejan2 commented Apr 5, 2023

The goal of this PR:

  • when MmaOp is constructed, there are no checks to validate that the provided parameters (inputs/outputs) are valid,
  • this resulted in implementing consistency checks as rules in the matmul scheduler's compile-time checks,
  • this PR covers part of the follow-up tasks mentioned in this comment in Enable matmul scheduler in segmenter #23

The scope of changes:

  • move some checks from the matmul scheduler's compile-time checks to the MmaOp constructor,
  • gather the M / N / K / Batch axes from MmaOp inputs/outputs,
  • move the input layout check from the matmul scheduler's compile-time checks to the MmaOp constructor,
  • the M / N / K / Batch axes and the inputs' layout are stored as attributes on the MmaOp object.

Tests will be re-enabled in a follow-up PR that adds MmaOp handling of the scenarios covered by these tests.

Verification results:

  • C++ tests
    • cmd:
      .build/bin/nvfuser_tests
    • result:
      all tests passed

cc @mmigdal-nv

@drzejan2 drzejan2 force-pushed the ab/MmaOp_def_consistency_checks branch from 5aa83df to 34f252e Compare April 5, 2023 12:12
@drzejan2
Contributor Author

drzejan2 commented Apr 5, 2023

!build

@drzejan2 drzejan2 requested review from naoyam and zasdfgbnm April 5, 2023 13:59
@drzejan2 drzejan2 force-pushed the ab/MmaOp_def_consistency_checks branch from 34f252e to b624b73 Compare April 5, 2023 15:09
@drzejan2
Contributor Author

drzejan2 commented Apr 5, 2023

!build

Collaborator

@zasdfgbnm zasdfgbnm left a comment


Posting my existing comments.

@drzejan2 drzejan2 force-pushed the ab/MmaOp_def_consistency_checks branch 2 times, most recently from 944d183 to 2f6e541 Compare April 7, 2023 14:44
@drzejan2
Contributor Author

drzejan2 commented Apr 7, 2023

!build

Comment on lines +3180 to +3145
TORCH_CHECK(
    ir_utils::getMmaOps(fusion.get()).front()->inputLayout().has_value(),
    "input layout has not been set for MmaOp");
TORCH_CHECK(
    layout ==
        ir_utils::getMmaOps(fusion.get()).front()->inputLayout().value(),
    "input layout from test and MmaOp do not match");
Collaborator

nit: I think layout == ir_utils::getMmaOps(fusion.get()).front()->inputLayout() should be sufficient.

Contributor Author

@drzejan2 drzejan2 Apr 7, 2023

The tricky part is that inputLayout() returns an optional. I made it like this because MmaOp can be created for Vals that are instances of TensorView or TensorIndex. For the first, the layout can be determined (unless the inputs are incorrectly defined); for the second, I'm not sure how to handle it.

So currently, if MmaOp is created with TensorIndex, then the MAxes/NAxes/KAxes/BatchAxes/layout attributes are default-initialized (an empty std::vector<int> or an empty c10::optional<MmaOptions::MmaInputLayout>).

Collaborator

TensorIndex should only appear during lowering. If not, then this is an error. So you don't need to worry about TensorIndex in fusion IR.

Comment on lines +3232 to +3197
TORCH_CHECK(
    ir_utils::getMmaOps(fusion.get()).front()->inputLayout().has_value(),
    "input layout has not been set for MmaOp");
TORCH_CHECK(
    layout ==
        ir_utils::getMmaOps(fusion.get()).front()->inputLayout().value(),
    "input layout from test and MmaOp do not match");
Collaborator
same here

@drzejan2 drzejan2 force-pushed the ab/MmaOp_def_consistency_checks branch 2 times, most recently from 6b36c7e to 4a6571f Compare April 11, 2023 11:59
@drzejan2
Contributor Author

!build

@drzejan2 drzejan2 force-pushed the ab/MmaOp_def_consistency_checks branch from 4a6571f to 2f823e9 Compare April 12, 2023 09:57
@drzejan2
Contributor Author

!build

@mmigdal-nv mmigdal-nv self-requested a review April 12, 2023 09:59
Collaborator

@zasdfgbnm zasdfgbnm left a comment


Posting some final comments

}
};

const auto validateOutputDetails = [](const TensorViewDetails& details,
Collaborator

for the case of batched matmul, should we allow having broadcast and more than two concrete domains?

Contributor Author

We don't have support for batch yet, but I checked what the domains in the MmaOp output are for a single test with strided batches:

NVFuserTest.FusionAmpereStridedBatchedMatmulTN_CUDA

and it looks like this:

T4_l [ iS18{i0}, iS19{i2}, iS20{i6}, iS21{i3}, rS22{i4} ]

I'm not sure if the test represents the final approach for strided batches, so broadcasts could appear there.

For now I will keep the current implementation, but I will add a comment:

// TODO: revise rules when add support for batch gemms


@drzejan2 drzejan2 force-pushed the ab/MmaOp_def_consistency_checks branch from 2f823e9 to d31dda6 Compare April 13, 2023 07:18
@drzejan2
Contributor Author

!build

Collaborator

@zasdfgbnm zasdfgbnm left a comment


Posting a few coding style changes, please fix these before merging

Comment on lines +1506 to +1508
if (details.bcasts.empty()) {
TORCH_INTERNAL_ASSERT(false, desc, ": has no broadcast domains.");
}
Collaborator

nit:

Suggested change
if (details.bcasts.empty()) {
TORCH_INTERNAL_ASSERT(false, desc, ": has no broadcast domains.");
}
TORCH_INTERNAL_ASSERT(!details.bcasts.empty(), desc, ": has no broadcast domains.");

Contributor Author

I missed this, good point. I updated the PR, and if there are no issues I will merge.

Thanks!

Comment on lines +1509 to +1511
if (!details.rdomains.empty()) {
TORCH_INTERNAL_ASSERT(false, desc, ": has reduction domains.");
}
Collaborator

nit:

Suggested change
if (!details.rdomains.empty()) {
TORCH_INTERNAL_ASSERT(false, desc, ": has reduction domains.");
}
TORCH_INTERNAL_ASSERT(details.rdomains.empty(), desc, ": has reduction domains.");

Comment on lines +1512 to +1520
if (details.cdomains.size() < expected_gemm_cdomains) {
TORCH_INTERNAL_ASSERT(
false,
desc,
": has unsupported number of concrete domains, expected at least ",
expected_gemm_cdomains,
", got ",
details.cdomains.size());
}
Collaborator

nit:

Suggested change
if (details.cdomains.size() < expected_gemm_cdomains) {
TORCH_INTERNAL_ASSERT(
false,
desc,
": has unsupported number of concrete domains, expected at least ",
expected_gemm_cdomains,
", got ",
details.cdomains.size());
}
TORCH_INTERNAL_ASSERT(
details.cdomains.size() >= expected_gemm_cdomains,
desc,
": has unsupported number of concrete domains, expected at least ",
expected_gemm_cdomains,
", got ",
details.cdomains.size());

Comment on lines +1526 to +1528
if (!details.bcasts.empty()) {
TORCH_INTERNAL_ASSERT(false, desc, ": has broadcast domains.");
}
Collaborator

nit:

Suggested change
if (!details.bcasts.empty()) {
TORCH_INTERNAL_ASSERT(false, desc, ": has broadcast domains.");
}
TORCH_INTERNAL_ASSERT(details.bcasts.empty(), desc, ": has broadcast domains.");

Comment on lines +1529 to +1531
if (details.rdomains.empty()) {
TORCH_INTERNAL_ASSERT(false, desc, ": has no reduction domains.");
}
Collaborator

nit:

Suggested change
if (details.rdomains.empty()) {
TORCH_INTERNAL_ASSERT(false, desc, ": has no reduction domains.");
}
TORCH_INTERNAL_ASSERT(!details.rdomains.empty(), desc, ": has no reduction domains.");

Comment on lines +1532 to +1540
if (details.cdomains.size() < expected_gemm_cdomains) {
TORCH_INTERNAL_ASSERT(
false,
desc,
": has unsupported number of concrete domains, expected at least ",
expected_gemm_cdomains,
", got ",
details.cdomains.size());
}
Collaborator

nit:

Suggested change
if (details.cdomains.size() < expected_gemm_cdomains) {
TORCH_INTERNAL_ASSERT(
false,
desc,
": has unsupported number of concrete domains, expected at least ",
expected_gemm_cdomains,
", got ",
details.cdomains.size());
}
TORCH_INTERNAL_ASSERT(
details.cdomains.size() >= expected_gemm_cdomains,
desc,
": has unsupported number of concrete domains, expected at least ",
expected_gemm_cdomains,
", got ",
details.cdomains.size());

- move some checks from compile time checks in matmul scheduler to
  MmaOp constructor,
- move input layout check from matmul scheduler compile time checks
  to MmaOp class,
- extend the set of attributes associated with MmaOp,
- update compile time checks in matmul scheduler,
@drzejan2 drzejan2 force-pushed the ab/MmaOp_def_consistency_checks branch from d31dda6 to a992c0e Compare April 13, 2023 16:52
@drzejan2 drzejan2 merged commit 7a3b24e into main Apr 13, 2023
@drzejan2 drzejan2 deleted the ab/MmaOp_def_consistency_checks branch April 13, 2023 16:57