From 373bbd45155f404a7936880afab12626d83e5082 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 8 Dec 2025 15:46:34 -0800 Subject: [PATCH 1/5] Fix detection of ldmatrix/stmatrix --- csrc/device_lower/pass/index.cpp | 41 ++++++++++++++++++++++++-------- csrc/device_lower/pass/index.h | 6 +++-- csrc/id_model/indexing.cpp | 32 +++++++++++++++++++------ csrc/id_model/indexing.h | 6 +++-- csrc/index_compute.cpp | 10 ++++---- csrc/index_compute.h | 6 +++-- 6 files changed, 74 insertions(+), 27 deletions(-) diff --git a/csrc/device_lower/pass/index.cpp b/csrc/device_lower/pass/index.cpp index 2ff58b5141c..0100f6ddbcb 100644 --- a/csrc/device_lower/pass/index.cpp +++ b/csrc/device_lower/pass/index.cpp @@ -40,7 +40,8 @@ Val* IndexLowering::lowerSrcIndex( Val* dst, const std::unordered_map& override_index, bool generate_pointer, - DataType as_type) const { + DataType as_type, + bool ld_st_matrix) const { if (auto tv = dynamic_cast(src)) { NVF_ERROR(dst->isA()); kir::TensorIndex* tind = Index::getProducerIndex( @@ -50,7 +51,8 @@ Val* IndexLowering::lowerSrcIndex( getRotatedLoop(), override_index, generate_pointer, - as_type); + as_type, + ld_st_matrix); if (TensorView* aliased_producer = GpuLower::current()->getTensorProducerAlias(tv)) { return IrBuilder::create( @@ -67,7 +69,8 @@ Val* IndexLowering::lowerDstIndex( Val* dst, const std::unordered_map& override_index, bool generate_pointer, - DataType as_type) const { + DataType as_type, + bool ld_st_matrix) const { if (auto tv = dynamic_cast(dst)) { return Index::getConsumerIndex( tv, @@ -75,7 +78,8 @@ Val* IndexLowering::lowerDstIndex( getRotatedLoop(), override_index, generate_pointer, - as_type); + as_type, + ld_st_matrix); } else { return dst; } @@ -2047,6 +2051,9 @@ void IndexLowering::handle(const LoadStoreOp* ldst) { return; } + const bool ld_st_matrix = + ir_utils::isLdMatrixOp(ldst) || ir_utils::isStMatrixOp(ldst); + if (ir_utils::isCpAsyncBulk(ldst)) { if (ir_utils::isCpAsyncBulkLoad(ldst)) { handleCpAsyncBulkLoad(ldst); @@ -2096,7 +2103,11 @@ void IndexLowering::handle(const LoadStoreOp* ldst) { case MmaInputSmemSwizzle::B64: case MmaInputSmemSwizzle::B32: { Val* index = GpuLower::current()->tensorIndexer().getLinearIndex( - in_tv, ldst, for_loops_); + in_tv, + ldst, + for_loops_, + /*override_index=*/{}, + /*ld_st_matrix=*/true); Val* offset = SimplifyingIrBuilder::mulExpr( index, dataTypeSizeByte(in_tv->dtype())); Val* smem_index = @@ -2113,7 +2124,7 @@ void IndexLowering::handle(const LoadStoreOp* ldst) { static_cast(num_regs)}; // Get the index for the input of stmatrix. - out = lowerDstIndex(ldst->out(), {}, false, as_type); + out = lowerDstIndex(ldst->out(), {}, false, as_type, ld_st_matrix); } else { as_type = ArrayType{ std::make_shared(DataType::UInt32), @@ -2154,7 +2165,11 @@ void IndexLowering::handle(const LoadStoreOp* ldst) { case MmaInputSmemSwizzle::B64: case MmaInputSmemSwizzle::B32: { Val* index = GpuLower::current()->tensorIndexer().getLinearIndex( - out_tv, ldst, for_loops_); + out_tv, + ldst, + for_loops_, + /*override_index=*/{}, + /*ld_st_matrix=*/true); Val* offset = SimplifyingIrBuilder::mulExpr( index, dataTypeSizeByte(out_tv->dtype())); Val* smem_index = @@ -2171,7 +2186,8 @@ void IndexLowering::handle(const LoadStoreOp* ldst) { static_cast(num_regs)}; // Get the index for the input of stmatrix. - in = lowerSrcIndex(ldst->in(), ldst->out(), {}, false, as_type); + in = lowerSrcIndex( + ldst->in(), ldst->out(), {}, false, as_type, ld_st_matrix); } else if (ldst->out()->definition()->isA()) { // For MMA accumulator initialization @@ -2208,7 +2224,8 @@ void IndexLowering::handle(const LoadStoreOp* ldst) { ldst->out(), {}, ir_utils::isLdMatrixOp(ldst) || ir_utils::isCpAsyncOp(ldst), - as_type); + as_type, + ld_st_matrix); } if (auto tv = dynamic_cast(ldst->out()); tv != nullptr && tv->getMemoryType() == MemoryType::Tensor) { @@ -2217,7 +2234,11 @@ void IndexLowering::handle(const LoadStoreOp* ldst) { tv, index, DataType::TMemAddress); } else { out = lowerDstIndex( - ldst->out(), {}, ir_utils::isCpAsyncOp(ldst), as_type); + ldst->out(), + {}, + ir_utils::isCpAsyncOp(ldst), + as_type, + ld_st_matrix); } } auto new_ldst = diff --git a/csrc/device_lower/pass/index.h b/csrc/device_lower/pass/index.h index d76524377ca..a4e7712c917 100644 --- a/csrc/device_lower/pass/index.h +++ b/csrc/device_lower/pass/index.h @@ -122,13 +122,15 @@ class IndexLowering : private OptOutConstDispatch { Val* dst, const std::unordered_map& override_index = {}, bool generate_pointer = false, - DataType as_type = DataType::Null) const; + DataType as_type = DataType::Null, + bool ld_st_matrix = false) const; Val* lowerDstIndex( Val* dst, const std::unordered_map& override_index = {}, bool generate_pointer = false, - DataType as_type = DataType::Null) const; + DataType as_type = DataType::Null, + bool ld_st_matrix = false) const; void handleCpAsyncBulkLoad(const LoadStoreOp* ldst); void handleCpAsyncBulkStore(const LoadStoreOp* ldst); diff --git a/csrc/id_model/indexing.cpp b/csrc/id_model/indexing.cpp index b3c56493172..eae488c82b4 100644 --- a/csrc/id_model/indexing.cpp +++ b/csrc/id_model/indexing.cpp @@ -203,7 +203,8 @@ Val* TensorIndexer::getLinearIndex( TensorView* tv, const Expr* expr, const std::vector& for_loops, - const std::unordered_map& override_index) const { + const std::unordered_map& override_index, + bool ld_st_matrix) const { NVF_ERROR(tv != nullptr); NVF_ERROR(expr != nullptr); NVF_ERROR( @@ -223,7 +224,13 @@ Val* TensorIndexer::getLinearIndex( const auto& alloc_info = getIndexAllocationInfo(tv); const auto [contig_indices, contig_strides] = getContigIndexFor( - tv, expr, as_consumer, alloc_info, for_loops, override_index); + tv, + expr, + as_consumer, + alloc_info, + for_loops, + override_index, + ld_st_matrix); // Linearize the indices with strides. Val* linear_index = tv->fusion()->zeroVal(); @@ -922,10 +929,17 @@ void TensorIndexer::ensureStaticIndexing( namespace { // Use alternate loop domain for the shared memory tensor for ldmatrix and -// stmatrix. -bool isSharedMemoryTvForLdStMatrix(TensorView* tv, const Expr* expr) { +// stmatrix. Note that the explicit bool indicator of the expr is +// required to correctly determine it is a ldmatrix/stmatrix op since +// there can be an initialization op using the same output tensor +// after the allocation lowering pass. +bool shouldUseAlternateLoopDomain( + TensorView* tv, + const Expr* expr, + bool ld_st_matrix) { // short-circuit: not (ldmatrix or stmatrix) - if (!ir_utils::isLdMatrixOp(expr) && !ir_utils::isStMatrixOp(expr)) { + if (!(ld_st_matrix && + (ir_utils::isLdMatrixOp(expr) || ir_utils::isStMatrixOp(expr)))) { return false; } // short-circuit: only the shared memory TensorView uses alternate loop @@ -960,7 +974,8 @@ std::pair, std::vector> TensorIndexer:: bool as_consumer, const AllocationDomainInfo& alloc_info, const std::vector& for_loops, - const std::unordered_map& override_index) const { + const std::unordered_map& override_index, + bool ld_st_matrix) const { std::vector indexed_ids; indexed_ids.reserve(alloc_info.ids.size()); for (const auto& id : alloc_info.ids) { @@ -969,7 +984,10 @@ std::pair, std::vector> TensorIndexer:: } } auto index_info = computeIndex( - expr, indexed_ids, for_loops, isSharedMemoryTvForLdStMatrix(tv, expr)); + expr, + indexed_ids, + for_loops, + shouldUseAlternateLoopDomain(tv, expr, ld_st_matrix)); for (const auto& [indexed_id, index] : override_index) { index_info.index_map[traversalGraph().toGroup(indexed_id)] = index; } diff --git a/csrc/id_model/indexing.h b/csrc/id_model/indexing.h index a37fde9e66d..94cc988e7a1 100644 --- a/csrc/id_model/indexing.h +++ b/csrc/id_model/indexing.h @@ -71,7 +71,8 @@ class TensorIndexer { TensorView* tv, const Expr* expr, const std::vector& loops, - const std::unordered_map& override_index = {}) const; + const std::unordered_map& override_index = {}, + bool ld_st_matrix = false) const; // Get the index of a loop domain. Val* getLoopIndex( @@ -93,7 +94,8 @@ class TensorIndexer { bool as_consumer, const AllocationDomainInfo& alloc_info, const std::vector& loops, - const std::unordered_map& override_index) const; + const std::unordered_map& override_index, + bool ld_st_matrix = false) const; // Grab all for-loops whose indices are actually used in the given // index vals. Note that IndexingInfo.loop_group_dependencies can be diff --git a/csrc/index_compute.cpp b/csrc/index_compute.cpp index ca3877a2147..909930b9d2e 100644 --- a/csrc/index_compute.cpp +++ b/csrc/index_compute.cpp @@ -2219,12 +2219,13 @@ kir::TensorIndex* Index::getProducerIndex( const std::unordered_set& rotated_loops, const std::unordered_map& override_index, bool generate_pointer, - DataType as_type) { + DataType as_type, + bool ld_st_matrix) { Val* index = nullptr; if (shouldUseTensorIndexer(producer, consumer, rotated_loops)) { index = GpuLower::current()->tensorIndexer().getLinearIndex( - producer, consumer->definition(), loops, override_index); + producer, consumer->definition(), loops, override_index, ld_st_matrix); if (generate_pointer) { auto address_offset = index; if (producer->getMemoryType() == MemoryType::Shared) { @@ -2323,7 +2324,8 @@ kir::TensorIndex* Index::getConsumerIndex( const std::unordered_set& rotated_loops, const std::unordered_map& override_index, bool generate_pointer, - DataType as_type) { + DataType as_type, + bool is_st_matrix) { Val* index = nullptr; if (!ir_utils::hasRootToLoopLinearTransformations(consumer) || ir_utils::isCpAsyncBulkLoad(consumer->definition()) || @@ -2331,7 +2333,7 @@ kir::TensorIndex* Index::getConsumerIndex( GpuLower::current()->tmemInfo().hasTMemTensor()) { NVF_ERROR(rotated_loops.empty(), "Loop rotation is not supported"); index = GpuLower::current()->tensorIndexer().getLinearIndex( - consumer, consumer->definition(), loops, override_index); + consumer, consumer->definition(), loops, override_index, is_st_matrix); if (generate_pointer) { auto address_offset = index; if (consumer->getMemoryType() == MemoryType::Shared) { diff --git a/csrc/index_compute.h b/csrc/index_compute.h index e3b3116160b..6edfdd287ed 100644 --- a/csrc/index_compute.h +++ b/csrc/index_compute.h @@ -496,7 +496,8 @@ class Index { const std::unordered_set& rotated_loops, const std::unordered_map& override_index = {}, bool generate_pointer = false, - DataType as_type = DataType::Null); + DataType as_type = DataType::Null, + bool ld_st_matrix = true); // Consumer index dispatch static kir::TensorIndex* getConsumerIndex( @@ -505,7 +506,8 @@ class Index { const std::unordered_set& rotated_loops, const std::unordered_map& override_index = {}, bool generate_pointer = false, - DataType as_type = DataType::Null); + DataType as_type = DataType::Null, + bool ld_st_matrix = true); //! Returns a vector of strided indices mapped onto the //! allocation domain of a producer tensor. The size of the returned From f5a45a4b1d22074f37c1ef24e3f3d37f2fd107e0 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 8 Dec 2025 18:44:10 -0800 Subject: [PATCH 2/5] cleanup --- csrc/id_model/indexing.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/csrc/id_model/indexing.cpp b/csrc/id_model/indexing.cpp index eae488c82b4..e192caac4d3 100644 --- a/csrc/id_model/indexing.cpp +++ b/csrc/id_model/indexing.cpp @@ -938,10 +938,12 @@ bool shouldUseAlternateLoopDomain( const Expr* expr, bool ld_st_matrix) { // short-circuit: not (ldmatrix or stmatrix) - if (!(ld_st_matrix && - (ir_utils::isLdMatrixOp(expr) || ir_utils::isStMatrixOp(expr)))) { + if (!ld_st_matrix) { return false; } + + NVF_ERROR(ir_utils::isLdMatrixOp(expr) || ir_utils::isStMatrixOp(expr)); + // short-circuit: only the shared memory TensorView uses alternate loop // domain. For ldmatrix, it is the input TensorView. For stmatrix, it is the // output TensorView. From 0378c79fadfa535bcbad689260ed87a98dc483ae Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 8 Dec 2025 20:42:44 -0800 Subject: [PATCH 3/5] fix --- csrc/id_model/indexing.cpp | 5 ++++- csrc/index_compute.h | 4 ++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/csrc/id_model/indexing.cpp b/csrc/id_model/indexing.cpp index e192caac4d3..237b8558415 100644 --- a/csrc/id_model/indexing.cpp +++ b/csrc/id_model/indexing.cpp @@ -942,7 +942,10 @@ bool shouldUseAlternateLoopDomain( return false; } - NVF_ERROR(ir_utils::isLdMatrixOp(expr) || ir_utils::isStMatrixOp(expr)); + NVF_ERROR( + ir_utils::isLdMatrixOp(expr) || ir_utils::isStMatrixOp(expr), + "Unexpected expr: ", + expr->toString()); // short-circuit: only the shared memory TensorView uses alternate loop // domain. For ldmatrix, it is the input TensorView. For stmatrix, it is the diff --git a/csrc/index_compute.h b/csrc/index_compute.h index 6edfdd287ed..34524c97bc3 100644 --- a/csrc/index_compute.h +++ b/csrc/index_compute.h @@ -497,7 +497,7 @@ class Index { const std::unordered_map& override_index = {}, bool generate_pointer = false, DataType as_type = DataType::Null, - bool ld_st_matrix = true); + bool ld_st_matrix = false); // Consumer index dispatch static kir::TensorIndex* getConsumerIndex( @@ -507,7 +507,7 @@ class Index { const std::unordered_map& override_index = {}, bool generate_pointer = false, DataType as_type = DataType::Null, - bool ld_st_matrix = true); + bool ld_st_matrix = false); //! Returns a vector of strided indices mapped onto the //! allocation domain of a producer tensor. The size of the returned From 144536c3e59ca4b5d34cb78ba4fb5b938e6b47ef Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 9 Dec 2025 09:51:19 -0800 Subject: [PATCH 4/5] Apply suggestion from @greptile-apps[bot] Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- csrc/index_compute.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/index_compute.cpp b/csrc/index_compute.cpp index 909930b9d2e..de002591095 100644 --- a/csrc/index_compute.cpp +++ b/csrc/index_compute.cpp @@ -2325,7 +2325,7 @@ kir::TensorIndex* Index::getConsumerIndex( const std::unordered_map& override_index, bool generate_pointer, DataType as_type, - bool is_st_matrix) { + bool ld_st_matrix) { Val* index = nullptr; if (!ir_utils::hasRootToLoopLinearTransformations(consumer) || ir_utils::isCpAsyncBulkLoad(consumer->definition()) || From cb755ef1d7d59022d725f26ef812e9b6bbb0c647 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 9 Dec 2025 10:01:15 -0800 Subject: [PATCH 5/5] Update csrc/index_compute.cpp Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> --- csrc/index_compute.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/index_compute.cpp b/csrc/index_compute.cpp index de002591095..ec1106909d5 100644 --- a/csrc/index_compute.cpp +++ b/csrc/index_compute.cpp @@ -2333,7 +2333,7 @@ kir::TensorIndex* Index::getConsumerIndex( GpuLower::current()->tmemInfo().hasTMemTensor()) { NVF_ERROR(rotated_loops.empty(), "Loop rotation is not supported"); index = GpuLower::current()->tensorIndexer().getLinearIndex( - consumer, consumer->definition(), loops, override_index, is_st_matrix); + consumer, consumer->definition(), loops, override_index, ld_st_matrix); if (generate_pointer) { auto address_offset = index; if (consumer->getMemoryType() == MemoryType::Shared) {