diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index 2ff58b5141c..0100f6ddbcb 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -40,7 +40,8 @@ Val* IndexLowering::lowerSrcIndex( Val* dst, const std::unordered_map& override_index, bool generate_pointer, - DataType as_type) const { + DataType as_type, + bool ld_st_matrix) const { if (auto tv = dynamic_cast(src)) { NVF_ERROR(dst->isA()); kir::TensorIndex* tind = Index::getProducerIndex( @@ -50,7 +51,8 @@ Val* IndexLowering::lowerSrcIndex( getRotatedLoop(), override_index, generate_pointer, - as_type); + as_type, + ld_st_matrix); if (TensorView* aliased_producer = GpuLower::current()->getTensorProducerAlias(tv)) { return IrBuilder::create( @@ -67,7 +69,8 @@ Val* IndexLowering::lowerDstIndex( Val* dst, const std::unordered_map& override_index, bool generate_pointer, - DataType as_type) const { + DataType as_type, + bool ld_st_matrix) const { if (auto tv = dynamic_cast(dst)) { return Index::getConsumerIndex( tv, @@ -75,7 +78,8 @@ Val* IndexLowering::lowerDstIndex( getRotatedLoop(), override_index, generate_pointer, - as_type); + as_type, + ld_st_matrix); } else { return dst; } @@ -2047,6 +2051,9 @@ void IndexLowering::handle(const LoadStoreOp* ldst) { return; } + const bool ld_st_matrix = + ir_utils::isLdMatrixOp(ldst) || ir_utils::isStMatrixOp(ldst); + if (ir_utils::isCpAsyncBulk(ldst)) { if (ir_utils::isCpAsyncBulkLoad(ldst)) { handleCpAsyncBulkLoad(ldst); @@ -2096,7 +2103,11 @@ void IndexLowering::handle(const LoadStoreOp* ldst) { case MmaInputSmemSwizzle::B64: case MmaInputSmemSwizzle::B32: { Val* index = GpuLower::current()->tensorIndexer().getLinearIndex( - in_tv, ldst, for_loops_); + in_tv, + ldst, + for_loops_, + /*override_index=*/{}, + /*ld_st_matrix=*/true); Val* offset = SimplifyingIrBuilder::mulExpr( index, dataTypeSizeByte(in_tv->dtype())); Val* smem_index = @@ -2113,7 +2124,7 @@ void IndexLowering::handle(const LoadStoreOp* ldst) { static_cast(num_regs)}; // Get the index for the input of stmatrix. - out = lowerDstIndex(ldst->out(), {}, false, as_type); + out = lowerDstIndex(ldst->out(), {}, false, as_type, ld_st_matrix); } else { as_type = ArrayType{ std::make_shared(DataType::UInt32), @@ -2154,7 +2165,11 @@ void IndexLowering::handle(const LoadStoreOp* ldst) { case MmaInputSmemSwizzle::B64: case MmaInputSmemSwizzle::B32: { Val* index = GpuLower::current()->tensorIndexer().getLinearIndex( - out_tv, ldst, for_loops_); + out_tv, + ldst, + for_loops_, + /*override_index=*/{}, + /*ld_st_matrix=*/true); Val* offset = SimplifyingIrBuilder::mulExpr( index, dataTypeSizeByte(out_tv->dtype())); Val* smem_index = @@ -2171,7 +2186,8 @@ void IndexLowering::handle(const LoadStoreOp* ldst) { static_cast(num_regs)}; // Get the index for the input of stmatrix. - in = lowerSrcIndex(ldst->in(), ldst->out(), {}, false, as_type); + in = lowerSrcIndex( + ldst->in(), ldst->out(), {}, false, as_type, ld_st_matrix); } else if (ldst->out()->definition()->isA()) { // For MMA accumulator initialization @@ -2208,7 +2224,8 @@ void IndexLowering::handle(const LoadStoreOp* ldst) { ldst->out(), {}, ir_utils::isLdMatrixOp(ldst) || ir_utils::isCpAsyncOp(ldst), - as_type); + as_type, + ld_st_matrix); } if (auto tv = dynamic_cast(ldst->out()); tv != nullptr && tv->getMemoryType() == MemoryType::Tensor) { @@ -2217,7 +2234,11 @@ void IndexLowering::handle(const LoadStoreOp* ldst) { tv, index, DataType::TMemAddress); } else { out = lowerDstIndex( - ldst->out(), {}, ir_utils::isCpAsyncOp(ldst), as_type); + ldst->out(), + {}, + ir_utils::isCpAsyncOp(ldst), + as_type, + ld_st_matrix); } } auto new_ldst = diff --git a/csrc/device_lower/pass/index.h b/csrc/device_lower/pass/index.h index d76524377ca..a4e7712c917 100644 --- a/csrc/device_lower/pass/index.h +++ b/csrc/device_lower/pass/index.h @@ -122,13 +122,15 @@ class IndexLowering : private OptOutConstDispatch { Val* dst, const std::unordered_map& override_index = {}, bool generate_pointer = false, - DataType as_type = DataType::Null) const; + DataType as_type = DataType::Null, + bool ld_st_matrix = false) const; Val* lowerDstIndex( Val* dst, const std::unordered_map& override_index = {}, bool generate_pointer = false, - DataType as_type = DataType::Null) const; + DataType as_type = DataType::Null, + bool ld_st_matrix = false) const; void handleCpAsyncBulkLoad(const LoadStoreOp* ldst); void handleCpAsyncBulkStore(const LoadStoreOp* ldst); diff --git a/csrc/id_model/indexing.cpp b/csrc/id_model/indexing.cpp index b3c56493172..237b8558415 100644 --- a/csrc/id_model/indexing.cpp +++ b/csrc/id_model/indexing.cpp @@ -203,7 +203,8 @@ Val* TensorIndexer::getLinearIndex( TensorView* tv, const Expr* expr, const std::vector& for_loops, - const std::unordered_map& override_index) const { + const std::unordered_map& override_index, + bool ld_st_matrix) const { NVF_ERROR(tv != nullptr); NVF_ERROR(expr != nullptr); NVF_ERROR( @@ -223,7 +224,13 @@ Val* TensorIndexer::getLinearIndex( const auto& alloc_info = getIndexAllocationInfo(tv); const auto [contig_indices, contig_strides] = getContigIndexFor( - tv, expr, as_consumer, alloc_info, for_loops, override_index); + tv, + expr, + as_consumer, + alloc_info, + for_loops, + override_index, + ld_st_matrix); // Linearize the indices with strides. Val* linear_index = tv->fusion()->zeroVal(); @@ -922,12 +929,24 @@ void TensorIndexer::ensureStaticIndexing( namespace { // Use alternate loop domain for the shared memory tensor for ldmatrix and -// stmatrix. -bool isSharedMemoryTvForLdStMatrix(TensorView* tv, const Expr* expr) { +// stmatrix. Note that the explicit bool indicator of the expr is +// required to correctly determine it is a ldmatrix/stmatrix op since +// there can be an initialization op using the same output tensor +// after the allocation lowering pass. +bool shouldUseAlternateLoopDomain( + TensorView* tv, + const Expr* expr, + bool ld_st_matrix) { // short-circuit: not (ldmatrix or stmatrix) - if (!ir_utils::isLdMatrixOp(expr) && !ir_utils::isStMatrixOp(expr)) { + if (!ld_st_matrix) { return false; } + + NVF_ERROR( + ir_utils::isLdMatrixOp(expr) || ir_utils::isStMatrixOp(expr), + "Unexpected expr: ", + expr->toString()); + // short-circuit: only the shared memory TensorView uses alternate loop // domain. For ldmatrix, it is the input TensorView. For stmatrix, it is the // output TensorView. @@ -960,7 +979,8 @@ std::pair, std::vector> TensorIndexer:: bool as_consumer, const AllocationDomainInfo& alloc_info, const std::vector& for_loops, - const std::unordered_map& override_index) const { + const std::unordered_map& override_index, + bool ld_st_matrix) const { std::vector indexed_ids; indexed_ids.reserve(alloc_info.ids.size()); for (const auto& id : alloc_info.ids) { @@ -969,7 +989,10 @@ std::pair, std::vector> TensorIndexer:: } } auto index_info = computeIndex( - expr, indexed_ids, for_loops, isSharedMemoryTvForLdStMatrix(tv, expr)); + expr, + indexed_ids, + for_loops, + shouldUseAlternateLoopDomain(tv, expr, ld_st_matrix)); for (const auto& [indexed_id, index] : override_index) { index_info.index_map[traversalGraph().toGroup(indexed_id)] = index; } diff --git a/csrc/id_model/indexing.h b/csrc/id_model/indexing.h index a37fde9e66d..94cc988e7a1 100644 --- a/csrc/id_model/indexing.h +++ b/csrc/id_model/indexing.h @@ -71,7 +71,8 @@ class TensorIndexer { TensorView* tv, const Expr* expr, const std::vector& loops, - const std::unordered_map& override_index = {}) const; + const std::unordered_map& override_index = {}, + bool ld_st_matrix = false) const; // Get the index of a loop domain. Val* getLoopIndex( @@ -93,7 +94,8 @@ class TensorIndexer { bool as_consumer, const AllocationDomainInfo& alloc_info, const std::vector& loops, - const std::unordered_map& override_index) const; + const std::unordered_map& override_index, + bool ld_st_matrix = false) const; // Grab all for-loops whose indices are actually used in the given // index vals. Note that IndexingInfo.loop_group_dependencies can be diff --git a/csrc/index_compute.cpp b/csrc/index_compute.cpp index ca3877a2147..ec1106909d5 100644 --- a/csrc/index_compute.cpp +++ b/csrc/index_compute.cpp @@ -2219,12 +2219,13 @@ kir::TensorIndex* Index::getProducerIndex( const std::unordered_set& rotated_loops, const std::unordered_map& override_index, bool generate_pointer, - DataType as_type) { + DataType as_type, + bool ld_st_matrix) { Val* index = nullptr; if (shouldUseTensorIndexer(producer, consumer, rotated_loops)) { index = GpuLower::current()->tensorIndexer().getLinearIndex( - producer, consumer->definition(), loops, override_index); + producer, consumer->definition(), loops, override_index, ld_st_matrix); if (generate_pointer) { auto address_offset = index; if (producer->getMemoryType() == MemoryType::Shared) { @@ -2323,7 +2324,8 @@ kir::TensorIndex* Index::getConsumerIndex( const std::unordered_set& rotated_loops, const std::unordered_map& override_index, bool generate_pointer, - DataType as_type) { + DataType as_type, + bool ld_st_matrix) { Val* index = nullptr; if (!ir_utils::hasRootToLoopLinearTransformations(consumer) || ir_utils::isCpAsyncBulkLoad(consumer->definition()) || @@ -2331,7 +2333,7 @@ kir::TensorIndex* Index::getConsumerIndex( GpuLower::current()->tmemInfo().hasTMemTensor()) { NVF_ERROR(rotated_loops.empty(), "Loop rotation is not supported"); index = GpuLower::current()->tensorIndexer().getLinearIndex( - consumer, consumer->definition(), loops, override_index); + consumer, consumer->definition(), loops, override_index, ld_st_matrix); if (generate_pointer) { auto address_offset = index; if (consumer->getMemoryType() == MemoryType::Shared) { diff --git a/csrc/index_compute.h b/csrc/index_compute.h index e3b3116160b..34524c97bc3 100644 --- a/csrc/index_compute.h +++ b/csrc/index_compute.h @@ -496,7 +496,8 @@ class Index { const std::unordered_set& rotated_loops, const std::unordered_map& override_index = {}, bool generate_pointer = false, - DataType as_type = DataType::Null); + DataType as_type = DataType::Null, + bool ld_st_matrix = false); // Consumer index dispatch static kir::TensorIndex* getConsumerIndex( @@ -505,7 +506,8 @@ class Index { const std::unordered_set& rotated_loops, const std::unordered_map& override_index = {}, bool generate_pointer = false, - DataType as_type = DataType::Null); + DataType as_type = DataType::Null, + bool ld_st_matrix = false); //! Returns a vector of strided indices mapped onto the //! allocation domain of a producer tensor. The size of the returned