Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions csrc/id_model/id_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -875,4 +875,28 @@ std::unordered_map<ValGroup, IterDomain*> updateValGroupIdMap(
return new_map;
}

// Mostly just copied from ComputeAtMap::validateAndPropagatePType
void IdModel::validateAndPropagatePType() {
for (const ValGroup& loop_group :
idGraph(IdMappingMode::LOOP).disjointValSets().disjointSets()) {
ParallelType common_ptype = ParallelType::Serial;
for (Val* id : *loop_group) {
auto id_ptype = id->as<IterDomain>()->getParallelType();
NVF_ERROR(
id_ptype == common_ptype || id_ptype == ParallelType::Serial ||
common_ptype == ParallelType::Serial,
"Issue validating parallel type disjoint ptype is, ",
common_ptype,
" but found in the set the id: ",
id->toString());
common_ptype =
common_ptype == ParallelType::Serial ? id_ptype : common_ptype;
}

for (auto id : *loop_group) {
id->as<IterDomain>()->parallelize(common_ptype);
}
}
}

} // namespace nvfuser
5 changes: 5 additions & 0 deletions csrc/id_model/id_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,11 @@ class IdModel : public PolymorphicBase {
// replayed expression and adding potential mappings through the expression.
Expr* addReplayAs(std::vector<IterDomain*> new_inputs, Expr* expr);

//! Run through disjoint sets in the LOOP graph, make sure there's only one
//! non-serial parallel type in each disjoint set, set the parallel type of
//! all IterDomains in the disjoint set to that PType.
void validateAndPropagatePType();

protected:
// Fills id_uses_ and id_definitions_ for all IterDomains active in the
// fusion.
Expand Down
29 changes: 29 additions & 0 deletions tests/cpp/test_id_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2585,4 +2585,33 @@ TEST_F(IdModelTest, LoopPromotionCoverage) {
}
}

TEST_F(IdModelTest, ParallelTypePropagation) {
Fusion fusion;
FusionGuard fg(&fusion);

auto tv0 = makeSymbolicTensor(1);
fusion.addInput(tv0);

auto tv1 = set(tv0);
auto tv2 = set(tv1);
fusion.addOutput(tv2);

tv2->split(0, 4);
TransformPropagatorWithCheck propagator(tv2);
MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator);

inlineMost();

tv2->axis(0)->parallelize(ParallelType::BIDx);
tv2->axis(1)->parallelize(ParallelType::TIDx);

IdModel id_model(&fusion);
id_model.validateAndPropagatePType();

EXPECT_EQ(tv1->axis(0)->getParallelType(), tv2->axis(0)->getParallelType())
<< "Parallel type propagation failed";
EXPECT_EQ(tv1->axis(1)->getParallelType(), tv2->axis(1)->getParallelType())
<< "Parallel type propagation failed";
}

} // namespace nvfuser