Conversation
|
Fuser/tests/cpp/test_combine_mul_sum.cpp Lines 134 to 146 in 5b5ec8f
|
This also adds testing of exact mapping to the node tests (WIP).
Use IdModel to define ID roles (and hence TV roles). Also use allocation domains to determine whether each operand has K as its inner dimension.
Tests pass!
This replaces the `CombineMulSum` class with `MatmulPattern` in the Matmul scheduler. Additionally, we use these matmul patterns to determine the problem layout, IterDomain roles, and TensorView roles. The allocation domain is used to determine the problem layout. The matmul scheduler is updated to reject segments whose input allocation domains are non-trivial (until that is supported eg. by #2226). Note that this does not add handling of `MatmulOp` and `LinearOp` in the matmul scheduler. That will be done next in #2236 or similar. --------- Co-authored-by: Priya Mishra <52657555+Priya2698@users.noreply.github.com> Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>
Co-authored-by: Priya Mishra <52657555+Priya2698@users.noreply.github.com>
|
!build |
csrc/scheduler/mma_utils.cpp
Outdated
| dim_roles[exact_graph.toGroup(A->axis(-1))] = MatmulDomain::K; | ||
| NVF_ERROR(A->nDims() > 0 && B->nDims() > 0); | ||
| size_t m_and_k_dims = 0; | ||
| if (A->nDims() == 1 && B->nDims() == 1) { | ||
| NVF_ERROR( | ||
| false, "MatmulOp node should not be created when both inputs are 1D"); | ||
| } else if (A->nDims() == 1) { | ||
| // Missing M dimension | ||
| dim_roles[exact_graph.toGroup(B->axis(-1))] = MatmulDomain::N; | ||
| m_and_k_dims = 1; | ||
| } else if (B->nDims() == 1) { | ||
| // Missing N dimension | ||
| dim_roles[exact_graph.toGroup(A->axis(-2))] = MatmulDomain::M; | ||
| m_and_k_dims = 1; |
There was a problem hiding this comment.
I see. My suggestion was that the role information is embedded in the position of the iterdomain in the mapping output. For eg: out_size-3, out_size-2, out_size-1 are M, N, K respectively.
Could you clarify which approach is erroneous -- mappingMatmulOpIterDomain or this PR? Why should iS5 be M instead of batch?
tests/cpp/test_matmul_scheduler.cpp
Outdated
| TEST_F(MatmulSchedulerTest, Llama2FFN) { | ||
| NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); | ||
|
|
||
| for (bool enable_fusion : {false, true}) { |
There was a problem hiding this comment.
I would recommend parametrizing this, the parametrization pattern can be used with the SegmentMatmulOpPrologue and SegmentLinearOpPrologue tests as well to test with all three combinations of schedulers as listed in this comment --
// TODO: Once we can control the ExprEval and Matmul schedulers via options, run
// this test with all three combinations (with and without each scheduler, but
// at least one enabled).
There was a problem hiding this comment.
Oh that's a good idea
There was a problem hiding this comment.
I added a parametrization for the Llamma2FFN test, but the prologue tests are actually currently failing when fusion is enabled, so I will leave that for a separate PR (that might be fixed by #2309).
And enable 1d/1d linear tests
|
!build |
|
!build |
Priya2698
left a comment
There was a problem hiding this comment.
LGTM. Thanks for addressing the comments.
The purpose of this PR is to enable the NVFuser matmul scheduler to operate on the new
MatmulOpanLinearOpnodes. That means the matmul scheduler can optionally schedule segments if they are not supported by ATen'smatmulandlinearfunctions.Specifically, this PR:
MatmulOpandLinearOpto the set ofMatmulPatterns detected. Previously onlyMmaOpand combined mul-sum patterns were detected.MatmulOpandLinearOpto fixed broadcast+MmaOp patterns.EnableOption::FuseMatmuland associatedNVFUSER_ENABLE=fuse_matmuloption to enable the automatic scheduler to accept matmul patterns consisting of aMatmulOporLinearOp. By default, the matmul scheduler will not accept segments containingMatmulOporLinearOppatterns, meaning all those nodes will be computed using theExprEvalscheduler (ATen).Not all cases that are supported by these IR nodes can be translated to an
MmaOp. In particular:These correspond to test cases in the two new tests. See the
TODOcomments for descriptions of cases we plan to support but cannot yet translate.There is also a commented out test case for
LinearOp: M=N=1 which should not be translated to aLinearOpat all.