Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion csrc/id_model/indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,12 @@ void TensorIndexer::buildLoopIndexMap() {
ParallelType ptype = getParallelType(loop_group);
if (isParallelTypeThread(ptype)) {
loop_index = NamedScalar::getParallelIndex(ptype);
} else if (shouldUseZeroIndex(loop_group)) {
} else if (
// TODO: Cleanup needed. ir_utils::isMemoryPartitionedAcross
// should be used, but that means we would need to consider
// multiple outputs with different memory types, though it
// should be uncommon in practice.
shouldUseZeroIndex(loop_group) || isParallelTypeDeviceDim(ptype)) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could isParallelTypeDeviceDim(ptype) go inside shouldUseZeroIndex? If any ID in the group is parallelized DID then the loop must be trivial right?

loop_index = fusion->zeroVal();
} else {
loop_index = IrBuilder::create<Val>(DataType::Index);
Expand Down
203 changes: 203 additions & 0 deletions tests/cpp/test_indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,209 @@ TEST_F(IndexingTest, SimpleBroadcast4) {
EXPECT_EQ(tv2_producer_index, tv4_loop_indices.at(1));
}

// Trivial example with a 1D sharded tensor. Each device owns exactly
// one element, so the linear index must always be zero.
TEST_F(IndexingTest, MultiDevice1DNoSplitMerge) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto in_tv = makeSymbolicTensor(1);
  fusion.addInput(in_tv);

  auto out_tv = set(in_tv);
  fusion.addOutput(out_tv);

  // Shard the sole axis of both tensors across devices.
  for (auto tv : {in_tv, out_tv}) {
    tv->axis(0)->parallelize(ParallelType::DIDx);
  }

  IdModel id_model(&fusion);
  TensorIndexer indexer(id_model);

  auto def = out_tv->definition();
  EXPECT_TRUE(indexer.getLinearIndex(in_tv, def)->isZeroInt());
  EXPECT_TRUE(indexer.getLinearIndex(out_tv, def)->isZeroInt());
}

// Same fusion as MultiDevice1DNoSplitMerge but with an outer split by
// the device count.
TEST_F(IndexingTest, MultiDevice1DSplit) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto in_tv = makeContigTensor(1);
  fusion.addInput(in_tv);

  auto out_tv = set(in_tv);
  fusion.addOutput(out_tv);

  Val* num_devices = IrBuilder::create<Val>(DataType::Index);

  // Outer-split each tensor by the device count and shard the outer
  // axis across devices.
  for (auto tv : {in_tv, out_tv}) {
    tv->split(0, num_devices, false);
    tv->axis(0)->parallelize(ParallelType::DIDx);
  }

  IdModel id_model(&fusion);
  TensorIndexer indexer(id_model);

  const auto loop_indices = getLoopIndices(out_tv, indexer);

  auto producer_index = indexer.getLinearIndex(in_tv, out_tv->definition());
  auto consumer_index = indexer.getLinearIndex(out_tv, out_tv->definition());

  // Producer and consumer share the same reference: the inner,
  // non-device-parallel loop index.
  Val* expected = loop_indices.at(1);

  EXPECT_TRUE(producer_index->sameAs(expected))
      << "Ref: " << expected->toInlineString()
      << ". Actual: " << producer_index->toInlineString();

  EXPECT_TRUE(consumer_index->sameAs(expected))
      << "Ref: " << expected->toInlineString()
      << ". Actual: " << consumer_index->toInlineString();
}

// 2D tensor flattened and then outer-split by the device count.
TEST_F(IndexingTest, MultiDevice2D) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto in_tv = makeContigTensor(2);
  fusion.addInput(in_tv);

  auto out_tv = set(in_tv);
  fusion.addOutput(out_tv);

  Val* num_devices = IrBuilder::create<Val>(DataType::Index);

  out_tv->flatten();
  out_tv->split(0, num_devices, false);

  TransformPropagator propagator(out_tv);
  MaxRootDomainInfoSpanningTree(out_tv).traverse(&propagator);

  in_tv->axis(0)->parallelize(ParallelType::DIDx);
  out_tv->axis(0)->parallelize(ParallelType::DIDx);

  IdModel id_model(&fusion);
  TensorIndexer indexer(id_model);

  const auto loop_indices = getLoopIndices(out_tv, indexer);

  auto producer_index = indexer.getLinearIndex(in_tv, out_tv->definition());
  auto consumer_index = indexer.getLinearIndex(out_tv, out_tv->definition());

  auto inner_extent = out_tv->getLogicalDomain().at(1)->extent();

  // Note that the allocation domain is the logical domain, so the flat
  // loop index is decomposed back into row/column contributions. See
  // the next test for a leaf allocation example. Producer and consumer
  // use the same index.
  auto flat_idx = loop_indices.at(1);
  auto expected = addExpr(
      modExpr(flat_idx, inner_extent),
      mulExpr(divExpr(flat_idx, inner_extent), inner_extent));

  EXPECT_TRUE(producer_index->sameAs(expected))
      << "Ref: " << expected->toInlineString()
      << ". Actual: " << producer_index->toInlineString();

  EXPECT_TRUE(consumer_index->sameAs(expected))
      << "Ref: " << expected->toInlineString()
      << ". Actual: " << consumer_index->toInlineString();
}

// Same fusion as MultiDevice2D but with leaf allocation
TEST_F(IndexingTest, MultiDevice2DLeafAllocation) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto in_tv = makeContigTensor(2);
  fusion.addInput(in_tv);

  auto out_tv = set(in_tv);
  fusion.addOutput(out_tv);

  Val* num_devices = IrBuilder::create<Val>(DataType::Index);

  out_tv->flatten();
  out_tv->split(0, num_devices, false);

  TransformPropagator propagator(out_tv);
  MaxRootDomainInfoSpanningTree(out_tv).traverse(&propagator);

  // Shard the outer axis across devices and allocate both tensors in
  // their leaf domains.
  for (auto tv : {in_tv, out_tv}) {
    tv->axis(0)->parallelize(ParallelType::DIDx);
    tv->setAllocationDomain(tv->getLeafDomain(), true);
  }

  IdModel id_model(&fusion);
  TensorIndexer indexer(id_model);

  const auto loop_indices = getLoopIndices(out_tv, indexer);

  auto producer_index = indexer.getLinearIndex(in_tv, out_tv->definition());
  auto consumer_index = indexer.getLinearIndex(out_tv, out_tv->definition());

  // Since the leaf domain is the allocation domain, producer and
  // consumer both index with just the non-parallelized loop index.
  Val* expected = loop_indices.at(1);

  EXPECT_TRUE(producer_index->sameAs(expected))
      << "Ref: " << expected->toInlineString()
      << ". Actual: " << producer_index->toInlineString();

  EXPECT_TRUE(consumer_index->sameAs(expected))
      << "Ref: " << expected->toInlineString()
      << ". Actual: " << consumer_index->toInlineString();
}

// Transpose between a sharded producer and consumer. The two tensors
// have different allocation layouts, so their indices differ.
TEST_F(IndexingTest, MultiDevice2DTranspose) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  auto in_tv = makeContigTensor(2);
  fusion.addInput(in_tv);

  auto out_tv = transpose(in_tv);
  fusion.addOutput(out_tv);

  Val* num_devices = IrBuilder::create<Val>(DataType::Index);

  // Outer-split each tensor by the device count and shard the outer
  // axis across devices.
  for (auto tv : {in_tv, out_tv}) {
    tv->split(0, num_devices, false);
    tv->axis(0)->parallelize(ParallelType::DIDx);
  }

  IdModel id_model(&fusion);
  TensorIndexer indexer(id_model);

  const auto loop_indices = getLoopIndices(out_tv, indexer);

  auto producer_index = indexer.getLinearIndex(in_tv, out_tv->definition());
  auto consumer_index = indexer.getLinearIndex(out_tv, out_tv->definition());

  // Because of the transpose, the producer reference swaps the roles
  // of the two inner loop indices relative to the consumer reference.
  auto producer_ref = addExpr(
      loop_indices.at(1),
      mulExpr(loop_indices.at(2), in_tv->getLogicalDomain().at(1)->extent()));

  auto consumer_ref = addExpr(
      loop_indices.at(2),
      mulExpr(loop_indices.at(1), out_tv->getLogicalDomain().at(1)->extent()));

  EXPECT_TRUE(producer_index->sameAs(producer_ref))
      << "Ref: " << producer_ref->toInlineString()
      << ". Actual: " << producer_index->toInlineString();

  EXPECT_TRUE(consumer_index->sameAs(consumer_ref))
      << "Ref: " << consumer_ref->toInlineString()
      << ". Actual: " << consumer_index->toInlineString();
}

// Allocation of broadcast domains should not need to be promoted.
TEST_F(IndexingTest, PromotedBroadcast) {
Fusion fusion;
Expand Down