diff --git a/csrc/disjoint_set.h b/csrc/disjoint_set.h index 8d916ea2edf..618f768fff1 100644 --- a/csrc/disjoint_set.h +++ b/csrc/disjoint_set.h @@ -87,7 +87,8 @@ class VectorOfUniqueEntries { } // Returns true if any node was added - bool pushBack(const std::vector& other) { + template + bool pushBack(const std::vector& other) { bool any_added = false; for (const auto& entry : other) { auto added = pushBack(entry); diff --git a/csrc/id_model/id_model.cpp b/csrc/id_model/id_model.cpp index 33db3157ed6..c457ab0cf6b 100644 --- a/csrc/id_model/id_model.cpp +++ b/csrc/id_model/id_model.cpp @@ -543,6 +543,26 @@ StatefulInliningInfo buildStatefulInliningInfo( } } } + + // Siblings should always be mapped + auto consumer_tvs = ir_utils::filterByType(expr->outputs()); + if (consumer_tvs.size() > 1) { + auto all_consumer_ids = ir_utils::allIDsOf(consumer_tvs.vector().at(0)); + info.ordered_sibling_ids.pushBack( + {all_consumer_ids.begin(), all_consumer_ids.end()}); + for (const auto i : c10::irange(1, consumer_tvs.size())) { + auto consumer_tv_i = consumer_tvs.vector().at(i); + auto all_consumer_i_ids = ir_utils::allIDsOf(consumer_tv_i); + + auto sibling_map = + exact_graph.buildMapBetween(all_consumer_ids, all_consumer_i_ids); + + for (const auto& [c_id_1, c_ids] : sibling_map) { + NVF_ERROR(c_ids.size() == 1); + info.sibling_maps[c_id_1->as()].pushBack(c_ids); + } + } + } } return info; } @@ -565,6 +585,17 @@ void IdModel::initializeLoopGraph(const StatefulInliningInfo& info) { } } } + + // Similarly maps all sibling domains + for (IterDomain* id : info.ordered_sibling_ids) { + auto entry_it = info.sibling_maps.find(id); + if (entry_it != info.sibling_maps.end()) { + const VectorOfUniqueEntries& sibling_ids = entry_it->second; + for (Val* sibling_id : sibling_ids) { + idGraph(IdMappingMode::LOOP).mapVals(id, sibling_id); + } + } + } } void IdModel::buildLoopGraph() { diff --git a/csrc/id_model/id_model.h b/csrc/id_model/id_model.h index 618599e634e..5b60ff474c6 100644 --- a/csrc/id_model/id_model.h +++ b/csrc/id_model/id_model.h @@ -36,6 +36,12 @@ struct StatefulInliningInfo { // root domains std::unordered_map> p2c_root_broadcast_resolution_map; + + // All IDs of all first siblings + VectorOfUniqueEntries ordered_sibling_ids; + + // Mappings to other sibling IDs from ordered_sibling_ids + std::unordered_map> sibling_maps; }; StatefulInliningInfo buildStatefulInliningInfo( diff --git a/tests/cpp/test_id_model.cpp b/tests/cpp/test_id_model.cpp index cd6f97d0c0a..192e5bd1b6e 100644 --- a/tests/cpp/test_id_model.cpp +++ b/tests/cpp/test_id_model.cpp @@ -2347,4 +2347,35 @@ TEST_F(IdModelTest, ValGraphBFS4) { ASSERT_EQ(tv4_to_tv0, tv4_to_tv0_ref); } +// Make sure domains of sibling tensors are all mapped together in the +// LOOP graph even when those tensors are not inlined. +TEST_F(IdModelTest, LoopGraphWithSibling) { + Fusion fusion; + FusionGuard fg(&fusion); + + auto tv0 = makeSymbolicTensor(3); + fusion.addInput(tv0); + auto welford_out_tvs = Welford(tv0, {2}); + auto avg = welford_out_tvs.avg; + fusion.addOutput(avg); + + // Random scheduling + avg->split(-1, 4); + avg->merge(0); + avg->split(0, 8); + TransformPropagatorWithCheck propagator(avg); + MaxRootDomainInfoSpanningTree(avg).traverse(&propagator); + + IdModel id_model(&fusion); + const auto& loop_graph = id_model.idGraph(IdMappingMode::LOOP); + + for (auto welford_out : {welford_out_tvs.var_sum, welford_out_tvs.n}) { + for (const auto i : c10::irange(avg->nDims())) { + ASSERT_TRUE(loop_graph.disjointValSets().strictAreMapped( + avg->axis(i), welford_out->axis(i))) + << "Unmapped siblings: " << avg->axis(i)->toString() << ", " + << welford_out->axis(i)->toString(); + } + } +} } // namespace nvfuser