MmaOp - consistency checks for basic matmul #131
Force-pushed from 5aa83df to 34f252e
!build
Force-pushed from 34f252e to b624b73
!build
zasdfgbnm left a comment:
Posting my existing comments.
Force-pushed from 944d183 to 2f6e541
!build
```cpp
TORCH_CHECK(
    ir_utils::getMmaOps(fusion.get()).front()->inputLayout().has_value(),
    "input layout has not be set for MmaOp");
TORCH_CHECK(
    layout ==
        ir_utils::getMmaOps(fusion.get()).front()->inputLayout().value(),
    "input layout from test and MmaOp do not match");
```
nit: I think `layout == ir_utils::getMmaOps(fusion.get()).front()->inputLayout()` should be sufficient.
The tricky part is that `inputLayout()` returns an optional. I made it like this because `MmaOp` can be created for Vals that are instances of `TensorView` or `TensorIndex`. For the first, a layout can be created (unless the inputs are incorrectly defined); for the second, I'm not sure how to handle it.
So currently, if `MmaOp` is created with a `TensorIndex`, then the MAxes/NAxes/KAxes/BatchAxes/layout attributes are default-initialized (an empty `std::vector<int>` or an empty `c10::optional<MmaOptions::MmaInputLayout>`).
`TensorIndex` should only appear during lowering; if it appears elsewhere, that is an error. So you don't need to worry about `TensorIndex` in fusion IR.
Force-pushed from 6b36c7e to 4a6571f
!build
Force-pushed from 4a6571f to 2f823e9
!build
zasdfgbnm left a comment:
Posting some final comments.
```cpp
  }
};

const auto validateOutputDetails = [](const TensorViewDetails& details,
```
For the case of batched matmul, should we allow having broadcast and more than two concrete domains?
We don't have support for batch yet, but I checked what the domains are in the `MmaOp` output for a single test with strided batches:
`NVFuserTest.FusionAmpereStridedBatchedMatmulTN_CUDA`
and it looks like this:
`T4_l [ iS18{i0}, iS19{i2}, iS20{i6}, iS21{i3}, rS22{i4} ]`
I'm not sure if the test represents the final approach for strided batches, so broadcasts could appear there.
For now I will keep the current implementation, but I will add a comment:
`// TODO: revise rules when add support for batch gemms`
Force-pushed from 2f823e9 to d31dda6
!build
csrc/ir_nodes.cpp (Outdated)
```cpp
if (details.bcasts.empty()) {
  TORCH_INTERNAL_ASSERT(false, desc, ": has no broadcast domains.");
}
```
nit, suggested change:
```cpp
TORCH_INTERNAL_ASSERT(!details.bcasts.empty(), desc, ": has no broadcast domains.");
```
I missed this, good point. I updated the PR and, if there are no issues, I will merge.
Thanks!
csrc/ir_nodes.cpp (Outdated)
```cpp
if (!details.rdomains.empty()) {
  TORCH_INTERNAL_ASSERT(false, desc, ": has reduction domains.");
}
```
nit, suggested change:
```cpp
TORCH_INTERNAL_ASSERT(details.rdomains.empty(), desc, ": has reduction domains.");
```
csrc/ir_nodes.cpp (Outdated)
```cpp
if (details.cdomains.size() < expected_gemm_cdomains) {
  TORCH_INTERNAL_ASSERT(
      false,
      desc,
      ": has unsupported number of concrete domains, expected at least ",
      expected_gemm_cdomains,
      ", got ",
      details.cdomains.size());
}
```
nit, suggested change:
```cpp
TORCH_INTERNAL_ASSERT(
    details.cdomains.size() >= expected_gemm_cdomains,
    desc,
    ": has unsupported number of concrete domains, expected at least ",
    expected_gemm_cdomains,
    ", got ",
    details.cdomains.size());
```
csrc/ir_nodes.cpp (Outdated)
```cpp
if (!details.bcasts.empty()) {
  TORCH_INTERNAL_ASSERT(false, desc, ": has broadcast domains.");
}
```
nit, suggested change:
```cpp
TORCH_INTERNAL_ASSERT(details.bcasts.empty(), desc, ": has broadcast domains.");
```
csrc/ir_nodes.cpp (Outdated)
```cpp
if (details.rdomains.empty()) {
  TORCH_INTERNAL_ASSERT(false, desc, ": has no reduction domains.");
}
```
nit, suggested change:
```cpp
TORCH_INTERNAL_ASSERT(!details.rdomains.empty(), desc, ": has no reduction domains.");
```
csrc/ir_nodes.cpp (Outdated)
```cpp
if (details.cdomains.size() < expected_gemm_cdomains) {
  TORCH_INTERNAL_ASSERT(
      false,
      desc,
      ": has unsupported number of concrete domains, expected at least ",
      expected_gemm_cdomains,
      ", got ",
      details.cdomains.size());
}
```
nit, suggested change:
```cpp
TORCH_INTERNAL_ASSERT(
    details.cdomains.size() >= expected_gemm_cdomains,
    desc,
    ": has unsupported number of concrete domains, expected at least ",
    expected_gemm_cdomains,
    ", got ",
    details.cdomains.size());
```
- move some checks from compile-time checks in the matmul scheduler to the `MmaOp` constructor,
- move the input layout check from the matmul scheduler compile-time checks to the `MmaOp` class,
- extend the set of attributes associated with `MmaOp`,
- update compile-time checks in the matmul scheduler.
Force-pushed from d31dda6 to a992c0e
The goal of this PR:
The scope of changes:
- `MmaOp` constructor,
Tests will be re-enabled with a follow-up PR that will add in `MmaOp` handling of the scenarios covered by these tests.
Verification results:
- `.build/bin/nvfuser_tests`: all tests passed

cc @mmigdal-nv