Infer matmul dimension roles to compute vectorization #2303

jacobhinkle merged 31 commits into `main`
Conversation
NOTE: this is a WIP draft that will likely be split into multiple smaller PRs. This is an attempt to generalize our matmul scheduler by doing the following:

1. Support more than 2 operands in our `MatmulParams.SupportedVectorization` struct.
2. Properly infer vectorization for every operand, epilogue input, and output.
3. Compute a canonical dim ordering on `ValGroup`s. This is used to compute vectorization properly, but can also be used for canonicalization of loop domains in `scheduleMatmul` in a future PR.
4. Schedule each tensor according to its supported vectorization. This might imply a new loop. For example, if there are two outputs and one supports only a vectorization width of 4 and the other 8, then we will unroll a loop of size 2 for the width-4 writes so that the outer loops are still inlined properly.

This is in preparation for further generalization to accommodate multiple MmaOps in a single Fusion.
!build
!build
!build
I had mistakenly thought we needed the leaf domain to handle the multi-device cases, but we don't, and that is confusing. I think this way is more reliable.
!build --diff
I think the code diffs are just due to the uncommented tests.
|
What if my mma output is (M=1024, N=1024), and my epilogue is:

In that case, I would expect T3 to not be mapped as an
Oh, I see. We are effectively rejecting view ops inside the fusion. And this rejection saves us from having to use the vectorization helper. |
Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>
!build |
This failure was just due to using the exact graph, while we now use the permissive graph instead.
This updates `scheduleMatmul` to use the dimension roles introduced in #2303 to reorder and merge dims within each role. That means we can schedule fusions with more than one tensor in each role. Two tests are included so far:

- MultipleMDims, which performs [M1, M2, K] @ [N, K]. This corresponds to `torch.linear` with a 3D input, i.e. a Linear layer with two "batch" dimensions.
- MultipleMDimsBatch, which performs [M1, B, M2, K] @ [B, N, K]. This shows that M dimensions need not be contiguous. Note that vectorization is unaffected in this example since K is innermost.

I plan to add more tests to explore more scenarios. Note that we cannot yet handle cases where K is not the innermost dimension, and I have not yet tested cases with M as the innermost output dimension. These will likely require us to modify the fusion definition to place those dimensions in the right spot in the MmaOp.

---------

Co-authored-by: Gao, Xiang <qasdfgtyuiop@gmail.com>
This PR does the following:
- Renames `RolesMap` to `TensorRolesMap` and introduces `DimRolesMap`, which is a mapping from `ValGroup` to `MatmulDomain`.
- Computes a canonical dim ordering on `ValGroup`s based on allocation domains of inputs and outputs. This is used to compute vectorization properly, but can also be used for canonicalization of loop domains in `scheduleMatmul` in a future PR.

This is in preparation for further generalization to accommodate multiple MmaOps in a single Fusion.
Fixes #2169.