From 9fc149648cd59b14ab24e8dff66f7e5e88aba828 Mon Sep 17 00:00:00 2001
From: Naoya Maruyama
Date: Fri, 7 Jun 2024 13:54:19 -0700
Subject: [PATCH] Propagate parallel types in loop graph

---
Notes (ignored when the patch is applied):
  * Adds IdModel::validateAndPropagatePType(). For every disjoint set of the
    LOOP graph it checks that at most one distinct non-Serial parallel type
    appears among the members (NVF_ERROR otherwise), then assigns that type
    (Serial if none was found) to every member of the set.
  * The members are accessed as IterDomain, since getParallelType() and
    parallelize() are invoked on them.
  * Adds IdModelTest.ParallelTypePropagation, which parallelizes only tv2 and
    expects tv1's two loop axes to end up with the same parallel types.

 csrc/id_model/id_model.cpp  | 24 ++++++++++++++++++++++++
 csrc/id_model/id_model.h    |  5 +++++
 tests/cpp/test_id_model.cpp | 29 +++++++++++++++++++++++++++++
 3 files changed, 58 insertions(+)

diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp
index 77e105e1b6b..a8365ee0d9e 100644
--- a/csrc/id_model/id_model.cpp
+++ b/csrc/id_model/id_model.cpp
@@ -875,4 +875,28 @@ std::unordered_map updateValGroupIdMap(
   return new_map;
 }
 
+// Mostly just copied from ComputeAtMap::validateAndPropagatePType
+void IdModel::validateAndPropagatePType() {
+  for (const ValGroup& loop_group :
+       idGraph(IdMappingMode::LOOP).disjointValSets().disjointSets()) {
+    ParallelType common_ptype = ParallelType::Serial;
+    for (Val* id : *loop_group) {
+      auto id_ptype = id->as<IterDomain>()->getParallelType();
+      NVF_ERROR(
+          id_ptype == common_ptype || id_ptype == ParallelType::Serial ||
+              common_ptype == ParallelType::Serial,
+          "Issue validating parallel type disjoint ptype is, ",
+          common_ptype,
+          " but found in the set the id: ",
+          id->toString());
+      common_ptype =
+          common_ptype == ParallelType::Serial ? id_ptype : common_ptype;
+    }
+
+    for (auto id : *loop_group) {
+      id->as<IterDomain>()->parallelize(common_ptype);
+    }
+  }
+}
+
 } // namespace nvfuser
diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h
index b08f1e31235..5de3e32cc0a 100644
--- a/csrc/id_model/id_model.h
+++ b/csrc/id_model/id_model.h
@@ -210,6 +210,11 @@ class IdModel : public PolymorphicBase {
   // replayed expression and adding potential mappings through the expression.
   Expr* addReplayAs(std::vector new_inputs, Expr* expr);
 
+  //! Run through disjoint sets in the LOOP graph, make sure there's only one
+  //! non-serial parallel type in each disjoint set, set the parallel type of
+  //! all IterDomains in the disjoint set to that PType.
+  void validateAndPropagatePType();
+
  protected:
   // Fills id_uses_ and id_definitions_ for all IterDomains active in the
   // fusion.
diff --git a/tests/cpp/test_id_model.cpp b/tests/cpp/test_id_model.cpp
index 9f720d41a8c..5bc1b1a9ae4 100644
--- a/tests/cpp/test_id_model.cpp
+++ b/tests/cpp/test_id_model.cpp
@@ -2585,4 +2585,33 @@ TEST_F(IdModelTest, LoopPromotionCoverage) {
   }
 }
 
+TEST_F(IdModelTest, ParallelTypePropagation) {
+  Fusion fusion;
+  FusionGuard fg(&fusion);
+
+  auto tv0 = makeSymbolicTensor(1);
+  fusion.addInput(tv0);
+
+  auto tv1 = set(tv0);
+  auto tv2 = set(tv1);
+  fusion.addOutput(tv2);
+
+  tv2->split(0, 4);
+  TransformPropagatorWithCheck propagator(tv2);
+  MaxRootDomainInfoSpanningTree(tv2).traverse(&propagator);
+
+  inlineMost();
+
+  tv2->axis(0)->parallelize(ParallelType::BIDx);
+  tv2->axis(1)->parallelize(ParallelType::TIDx);
+
+  IdModel id_model(&fusion);
+  id_model.validateAndPropagatePType();
+
+  EXPECT_EQ(tv1->axis(0)->getParallelType(), tv2->axis(0)->getParallelType())
+      << "Parallel type propagation failed";
+  EXPECT_EQ(tv1->axis(1)->getParallelType(), tv2->axis(1)->getParallelType())
+      << "Parallel type propagation failed";
+}
+
 } // namespace nvfuser