From 9d9ee5ecf420fe75b90c2230bfa8f8b12581648e Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 24 May 2024 12:03:55 +0000 Subject: [PATCH] Compute matmul dim roles with no-devices leaf domain This fixes a bug introduced by #2272 in `test_multidevice` where we reject a matmul segment shaped like `[iDIDxMo, iMi, bN, iK]` for having too many M dimensions. Locally this still has a single M dimension so it is valid. This PR ignores device dims for the purposes of computing tensor roles and problem shape. Further issues we should look into: 1. As mentioned in #2272 we should proceed to handle multiple M, N, K, and Batch dimensions, although in this case the restriction was useful for surfacing this bug. 2. Even if the matmul scheduler is completely broken or disabled, the _reduction_ scheduler should have been able to schedule this fusion. However, it identified the reduction tensor as `isResharding` and removed it from the `reduction_tvs` list, causing a failure in `scheduleReduction`. We should clean up that check to be able to schedule this type of fusion as a reduction. 3. The rfactor domain is often used for scheduling utilities to inspect the logical size of tensors. However, because multidevice scheduling modifies the leaf domain before segmentation, we should probably audit our schedulers to ensure they use the leaf domain and ignore device dims where necessary. 4. I should also not forget to rerun `!build` before merging PRs :-). 
--- csrc/scheduler/mma_utils.cpp | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index af593d0b23a..4665ebf6f6b 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1170,10 +1170,10 @@ RolesMapOpt getTensorRoles( const auto findDims = [&dim_roles, &exact_graph](TensorView* tv) { DimPresence has; - for (IterDomain* id : - TensorDomain::noReductions(tv->getMaybeRFactorDomain())) { - if (id->isBroadcast()) { - // Broadcast domains won't exact map to concrete domains so skip them + for (IterDomain* id : TensorDomain::noReductions(tv->getLeafDomain())) { + if (id->isBroadcast() || id->isDeviceDim()) { + // Broadcast and device domains won't exact map to concrete domains so + // skip them continue; } const ValGroup& g = exact_graph.toGroup(id); @@ -1414,16 +1414,16 @@ class MatmulPatternMatcher : IterVisitor { ltv = getTensorviewPriorToCast(ltv); rtv = getTensorviewPriorToCast(rtv); - std::vector<IterDomain*> lrf = - TensorDomain::noReductions(ltv->getMaybeRFactorDomain()); - std::vector<IterDomain*> rrf = - TensorDomain::noReductions(rtv->getMaybeRFactorDomain()); + std::vector<IterDomain*> lrf = TensorDomain::noDevices( + TensorDomain::noReductions(ltv->getLeafDomain())); + std::vector<IterDomain*> rrf = TensorDomain::noDevices( + TensorDomain::noReductions(rtv->getLeafDomain())); // These sizes should match since ops::maybeBroadcast places BroadcastOps // for implicit broadcasting. 
NVF_ERROR(lrf.size() == rrf.size()); - const std::vector<IterDomain*>& red_root = - rop->out()->as<TensorView>()->getRootDomain(); + const std::vector<IterDomain*>& red_root = TensorDomain::noDevices( + rop->out()->as<TensorView>()->getRootDomain()); NVF_ERROR(red_root.size() == lrf.size()); // Find innermost M or N dimension in output // We will assume for now that the output rfactor domain matches the @@ -1537,10 +1537,10 @@ std::unordered_map<ValGroup, MatmulDimRole> MatmulPattern::getDimRoles( std::unordered_map<ValGroup, DimPresence> present_flags; const auto recordPresence = [&exact_graph, &present_flags]( TensorView* tv, size_t tensor_num) { - for (IterDomain* id : tv->getMaybeRFactorDomain()) { - if (id->isReduction() || id->isBroadcast()) { - // ignore reductions and broadcasts since they don't exact map to - // problem dims + for (IterDomain* id : tv->getLeafDomain()) { + if (id->isReduction() || id->isBroadcast() || id->isDeviceDim()) { + // ignore device, reductions, and broadcasts since they don't exact map + // to problem dims in the generated kernel continue; } const ValGroup& g = exact_graph.toGroup(id);