Skip to content
Merged
3 changes: 2 additions & 1 deletion csrc/disjoint_set.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ class VectorOfUniqueEntries {
}

// Returns true if any node was added
bool pushBack(const std::vector<T>& other) {
template <typename OtherType>
bool pushBack(const std::vector<OtherType>& other) {
bool any_added = false;
for (const auto& entry : other) {
auto added = pushBack(entry);
Expand Down
31 changes: 31 additions & 0 deletions csrc/id_model/id_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -543,6 +543,26 @@ StatefulInliningInfo buildStatefulInliningInfo(
}
}
}

// Siblings should always be mapped
auto consumer_tvs = ir_utils::filterByType<TensorView>(expr->outputs());
if (consumer_tvs.size() > 1) {
auto all_consumer_ids = ir_utils::allIDsOf(consumer_tvs.vector().at(0));
info.ordered_sibling_ids.pushBack(
{all_consumer_ids.begin(), all_consumer_ids.end()});
for (const auto i : c10::irange(1, consumer_tvs.size())) {
auto consumer_tv_i = consumer_tvs.vector().at(i);
auto all_consumer_i_ids = ir_utils::allIDsOf(consumer_tv_i);

auto sibling_map =
exact_graph.buildMapBetween(all_consumer_ids, all_consumer_i_ids);

for (const auto& [c_id_1, c_ids] : sibling_map) {
NVF_ERROR(c_ids.size() == 1);
info.sibling_maps[c_id_1->as<IterDomain>()].pushBack(c_ids);
}
}
}
}
return info;
}
Expand All @@ -565,6 +585,17 @@ void IdModel::initializeLoopGraph(const StatefulInliningInfo& info) {
}
}
}

// Similarly maps all sibling domains
for (IterDomain* id : info.ordered_sibling_ids) {
auto entry_it = info.sibling_maps.find(id);
if (entry_it != info.sibling_maps.end()) {
const VectorOfUniqueEntries<Val*>& sibling_ids = entry_it->second;
for (Val* sibling_id : sibling_ids) {
idGraph(IdMappingMode::LOOP).mapVals(id, sibling_id);
}
}
}
}

void IdModel::buildLoopGraph() {
Expand Down
6 changes: 6 additions & 0 deletions csrc/id_model/id_model.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ struct StatefulInliningInfo {
// root domains
std::unordered_map<IterDomain*, VectorOfUniqueEntries<IterDomain*>>
p2c_root_broadcast_resolution_map;

// For each expression with more than one TensorView output, all IDs of
// its first TensorView output, kept in insertion order
VectorOfUniqueEntries<IterDomain*> ordered_sibling_ids;

// Maps each ID of a first sibling (see ordered_sibling_ids) to the IDs
// of the other sibling outputs that it is paired with in the exact graph
std::unordered_map<IterDomain*, VectorOfUniqueEntries<Val*>> sibling_maps;
};

StatefulInliningInfo buildStatefulInliningInfo(
Expand Down
31 changes: 31 additions & 0 deletions tests/cpp/test_id_model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2347,4 +2347,35 @@ TEST_F(IdModelTest, ValGraphBFS4) {
ASSERT_EQ(tv4_to_tv0, tv4_to_tv0_ref);
}

// Sibling outputs of a multi-output expression (Welford here) must have
// their corresponding loop domains mapped in the LOOP graph, even though
// none of the tensors in this fusion is inlined.
TEST_F(IdModelTest, LoopGraphWithSibling) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto tv0 = makeSymbolicTensor(3);
  fusion.addInput(tv0);
  auto welford_out_tvs = Welford(tv0, {2});
  auto avg = welford_out_tvs.avg;
  fusion.addOutput(avg);

  // Arbitrary scheduling applied to avg; the spanning-tree traversal
  // below runs the propagator starting from avg.
  avg->split(-1, 4);
  avg->merge(0);
  avg->split(0, 8);
  TransformPropagatorWithCheck propagator(avg);
  MaxRootDomainInfoSpanningTree(avg).traverse(&propagator);

  IdModel id_model(&fusion);
  const auto& loop_graph = id_model.idGraph(IdMappingMode::LOOP);
  const auto& loop_val_sets = loop_graph.disjointValSets();

  // Compare every axis of avg against the same axis of each of the other
  // two Welford outputs.
  for (auto sibling_tv : {welford_out_tvs.var_sum, welford_out_tvs.n}) {
    for (const auto axis_idx : c10::irange(avg->nDims())) {
      auto avg_axis = avg->axis(axis_idx);
      auto sibling_axis = sibling_tv->axis(axis_idx);
      ASSERT_TRUE(loop_val_sets.strictAreMapped(avg_axis, sibling_axis))
          << "Unmapped siblings: " << avg_axis->toString() << ", "
          << sibling_axis->toString();
    }
  }
}
} // namespace nvfuser