diff --git a/csrc/scheduler/mma_utils.cpp b/csrc/scheduler/mma_utils.cpp index af593d0b23a..4665ebf6f6b 100644 --- a/csrc/scheduler/mma_utils.cpp +++ b/csrc/scheduler/mma_utils.cpp @@ -1170,10 +1170,10 @@ RolesMapOpt getTensorRoles( const auto findDims = [&dim_roles, &exact_graph](TensorView* tv) { DimPresence has; - for (IterDomain* id : - TensorDomain::noReductions(tv->getMaybeRFactorDomain())) { - if (id->isBroadcast()) { - // Broadcast domains won't exact map to concrete domains so skip them + for (IterDomain* id : TensorDomain::noReductions(tv->getLeafDomain())) { + if (id->isBroadcast() || id->isDeviceDim()) { + // Broadcast and device domains won't exact map to concrete domains so + // skip them continue; } const ValGroup& g = exact_graph.toGroup(id); @@ -1414,16 +1414,16 @@ class MatmulPatternMatcher : IterVisitor { ltv = getTensorviewPriorToCast(ltv); rtv = getTensorviewPriorToCast(rtv); - std::vector lrf = - TensorDomain::noReductions(ltv->getMaybeRFactorDomain()); - std::vector rrf = - TensorDomain::noReductions(rtv->getMaybeRFactorDomain()); + std::vector lrf = TensorDomain::noDevices( + TensorDomain::noReductions(ltv->getLeafDomain())); + std::vector rrf = TensorDomain::noDevices( + TensorDomain::noReductions(rtv->getLeafDomain())); // These sizes should match since ops::maybeBroadcast places BroadcastOps // for implicit broadcasting. NVF_ERROR(lrf.size() == rrf.size()); - const std::vector& red_root = - rop->out()->as()->getRootDomain(); + const std::vector& red_root = TensorDomain::noDevices( + rop->out()->as()->getRootDomain()); NVF_ERROR(red_root.size() == lrf.size()); // Find innermost M or N dimension in output // We will assume for now that the output rfactor domain matches the @@ -1537,10 +1537,10 @@ std::unordered_map MatmulPattern::getDimRoles( std::unordered_map present_flags; const auto recordPresence = [&exact_graph, &present_flags]( TensorView* tv, size_t tensor_num) { - for (IterDomain* id : tv->getMaybeRFactorDomain()) { - if (id->isReduction() || id->isBroadcast()) { - // ignore reductions and broadcasts since they don't exact map to - // problem dims + for (IterDomain* id : tv->getLeafDomain()) { + if (id->isReduction() || id->isBroadcast() || id->isDeviceDim()) { + // ignore device, reductions, and broadcasts since they don't exact map + // to problem dims in the generated kernel continue; } const ValGroup& g = exact_graph.toGroup(id);