Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions csrc/scheduler/mma_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1170,10 +1170,10 @@ RolesMapOpt getTensorRoles(

const auto findDims = [&dim_roles, &exact_graph](TensorView* tv) {
DimPresence has;
for (IterDomain* id :
TensorDomain::noReductions(tv->getMaybeRFactorDomain())) {
if (id->isBroadcast()) {
// Broadcast domains won't exact map to concrete domains so skip them
for (IterDomain* id : TensorDomain::noReductions(tv->getLeafDomain())) {
if (id->isBroadcast() || id->isDeviceDim()) {
// Broadcast and device domains won't exact map to concrete domains so
// skip them
continue;
}
const ValGroup& g = exact_graph.toGroup(id);
Expand Down Expand Up @@ -1414,16 +1414,16 @@ class MatmulPatternMatcher : IterVisitor {
ltv = getTensorviewPriorToCast(ltv);
rtv = getTensorviewPriorToCast(rtv);

std::vector<IterDomain*> lrf =
TensorDomain::noReductions(ltv->getMaybeRFactorDomain());
std::vector<IterDomain*> rrf =
TensorDomain::noReductions(rtv->getMaybeRFactorDomain());
std::vector<IterDomain*> lrf = TensorDomain::noDevices(
TensorDomain::noReductions(ltv->getLeafDomain()));
std::vector<IterDomain*> rrf = TensorDomain::noDevices(
TensorDomain::noReductions(rtv->getLeafDomain()));

// These sizes should match since ops::maybeBroadcast places BroadcastOps
// for implicit broadcasting.
NVF_ERROR(lrf.size() == rrf.size());
const std::vector<IterDomain*>& red_root =
rop->out()->as<TensorView>()->getRootDomain();
const std::vector<IterDomain*>& red_root = TensorDomain::noDevices(
rop->out()->as<TensorView>()->getRootDomain());
NVF_ERROR(red_root.size() == lrf.size());
// Find innermost M or N dimension in output
// We will assume for now that the output rfactor domain matches the
Expand Down Expand Up @@ -1537,10 +1537,10 @@ std::unordered_map<ValGroup, MatmulDomain> MatmulPattern::getDimRoles(
std::unordered_map<ValGroup, DimPresence> present_flags;
const auto recordPresence = [&exact_graph, &present_flags](
TensorView* tv, size_t tensor_num) {
for (IterDomain* id : tv->getMaybeRFactorDomain()) {
if (id->isReduction() || id->isBroadcast()) {
// ignore reductions and broadcasts since they don't exact map to
// problem dims
for (IterDomain* id : tv->getLeafDomain()) {
if (id->isReduction() || id->isBroadcast() || id->isDeviceDim()) {
// ignore device dims, reductions, and broadcasts since they don't exact map
// to problem dims in the generated kernel
continue;
}
const ValGroup& g = exact_graph.toGroup(id);
Expand Down