Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 31 additions & 10 deletions csrc/device_lower/pass/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,8 @@ Val* IndexLowering::lowerSrcIndex(
Val* dst,
const std::unordered_map<IterDomain*, Val*>& override_index,
bool generate_pointer,
DataType as_type) const {
DataType as_type,
bool ld_st_matrix) const {
if (auto tv = dynamic_cast<TensorView*>(src)) {
NVF_ERROR(dst->isA<TensorView>());
kir::TensorIndex* tind = Index::getProducerIndex(
Expand All @@ -50,7 +51,8 @@ Val* IndexLowering::lowerSrcIndex(
getRotatedLoop(),
override_index,
generate_pointer,
as_type);
as_type,
ld_st_matrix);
if (TensorView* aliased_producer =
GpuLower::current()->getTensorProducerAlias(tv)) {
return IrBuilder::create<kir::TensorIndex>(
Expand All @@ -67,15 +69,17 @@ Val* IndexLowering::lowerDstIndex(
Val* dst,
const std::unordered_map<IterDomain*, Val*>& override_index,
bool generate_pointer,
DataType as_type) const {
DataType as_type,
bool ld_st_matrix) const {
if (auto tv = dynamic_cast<TensorView*>(dst)) {
return Index::getConsumerIndex(
tv,
for_loops_,
getRotatedLoop(),
override_index,
generate_pointer,
as_type);
as_type,
ld_st_matrix);
} else {
return dst;
}
Expand Down Expand Up @@ -2047,6 +2051,9 @@ void IndexLowering::handle(const LoadStoreOp* ldst) {
return;
}

const bool ld_st_matrix =
ir_utils::isLdMatrixOp(ldst) || ir_utils::isStMatrixOp(ldst);

if (ir_utils::isCpAsyncBulk(ldst)) {
if (ir_utils::isCpAsyncBulkLoad(ldst)) {
handleCpAsyncBulkLoad(ldst);
Expand Down Expand Up @@ -2096,7 +2103,11 @@ void IndexLowering::handle(const LoadStoreOp* ldst) {
case MmaInputSmemSwizzle::B64:
case MmaInputSmemSwizzle::B32: {
Val* index = GpuLower::current()->tensorIndexer().getLinearIndex(
in_tv, ldst, for_loops_);
in_tv,
ldst,
for_loops_,
/*override_index=*/{},
/*ld_st_matrix=*/true);
Val* offset = SimplifyingIrBuilder::mulExpr(
index, dataTypeSizeByte(in_tv->dtype()));
Val* smem_index =
Expand All @@ -2113,7 +2124,7 @@ void IndexLowering::handle(const LoadStoreOp* ldst) {
static_cast<size_t>(num_regs)};

// Get the index for the output of ldmatrix.
out = lowerDstIndex(ldst->out(), {}, false, as_type);
out = lowerDstIndex(ldst->out(), {}, false, as_type, ld_st_matrix);
} else {
as_type = ArrayType{
std::make_shared<DataType>(DataType::UInt32),
Expand Down Expand Up @@ -2154,7 +2165,11 @@ void IndexLowering::handle(const LoadStoreOp* ldst) {
case MmaInputSmemSwizzle::B64:
case MmaInputSmemSwizzle::B32: {
Val* index = GpuLower::current()->tensorIndexer().getLinearIndex(
out_tv, ldst, for_loops_);
out_tv,
ldst,
for_loops_,
/*override_index=*/{},
/*ld_st_matrix=*/true);
Val* offset = SimplifyingIrBuilder::mulExpr(
index, dataTypeSizeByte(out_tv->dtype()));
Val* smem_index =
Expand All @@ -2171,7 +2186,8 @@ void IndexLowering::handle(const LoadStoreOp* ldst) {
static_cast<size_t>(num_regs)};

// Get the index for the input of stmatrix.
in = lowerSrcIndex(ldst->in(), ldst->out(), {}, false, as_type);
in = lowerSrcIndex(
ldst->in(), ldst->out(), {}, false, as_type, ld_st_matrix);

} else if (ldst->out()->definition()->isA<MmaOp>()) {
// For MMA accumulator initialization
Expand Down Expand Up @@ -2208,7 +2224,8 @@ void IndexLowering::handle(const LoadStoreOp* ldst) {
ldst->out(),
{},
ir_utils::isLdMatrixOp(ldst) || ir_utils::isCpAsyncOp(ldst),
as_type);
as_type,
ld_st_matrix);
}
if (auto tv = dynamic_cast<TensorView*>(ldst->out());
tv != nullptr && tv->getMemoryType() == MemoryType::Tensor) {
Expand All @@ -2217,7 +2234,11 @@ void IndexLowering::handle(const LoadStoreOp* ldst) {
tv, index, DataType::TMemAddress);
} else {
out = lowerDstIndex(
ldst->out(), {}, ir_utils::isCpAsyncOp(ldst), as_type);
ldst->out(),
{},
ir_utils::isCpAsyncOp(ldst),
as_type,
ld_st_matrix);
}
}
auto new_ldst =
Expand Down
6 changes: 4 additions & 2 deletions csrc/device_lower/pass/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,13 +122,15 @@ class IndexLowering : private OptOutConstDispatch {
Val* dst,
const std::unordered_map<IterDomain*, Val*>& override_index = {},
bool generate_pointer = false,
DataType as_type = DataType::Null) const;
DataType as_type = DataType::Null,
bool ld_st_matrix = false) const;

Val* lowerDstIndex(
Val* dst,
const std::unordered_map<IterDomain*, Val*>& override_index = {},
bool generate_pointer = false,
DataType as_type = DataType::Null) const;
DataType as_type = DataType::Null,
bool ld_st_matrix = false) const;

void handleCpAsyncBulkLoad(const LoadStoreOp* ldst);
void handleCpAsyncBulkStore(const LoadStoreOp* ldst);
Expand Down
37 changes: 30 additions & 7 deletions csrc/id_model/indexing.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ Val* TensorIndexer::getLinearIndex(
TensorView* tv,
const Expr* expr,
const std::vector<kir::ForLoop*>& for_loops,
const std::unordered_map<IterDomain*, Val*>& override_index) const {
const std::unordered_map<IterDomain*, Val*>& override_index,
bool ld_st_matrix) const {
NVF_ERROR(tv != nullptr);
NVF_ERROR(expr != nullptr);
NVF_ERROR(
Expand All @@ -223,7 +224,13 @@ Val* TensorIndexer::getLinearIndex(
const auto& alloc_info = getIndexAllocationInfo(tv);

const auto [contig_indices, contig_strides] = getContigIndexFor(
tv, expr, as_consumer, alloc_info, for_loops, override_index);
tv,
expr,
as_consumer,
alloc_info,
for_loops,
override_index,
ld_st_matrix);

// Linearize the indices with strides.
Val* linear_index = tv->fusion()->zeroVal();
Expand Down Expand Up @@ -922,12 +929,24 @@ void TensorIndexer::ensureStaticIndexing(
namespace {

// Use alternate loop domain for the shared memory tensor for ldmatrix and
// stmatrix.
bool isSharedMemoryTvForLdStMatrix(TensorView* tv, const Expr* expr) {
// stmatrix. Note that the explicit bool indicator is required to
// correctly determine whether the expr is an ldmatrix/stmatrix op,
// since there can be an initialization op using the same output
// tensor after the allocation lowering pass.
bool shouldUseAlternateLoopDomain(
Comment on lines +932 to +936
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit of an ugly fix, but at the time of index lowering, the expr parameter doesn't necessarily refer to the actual expression for the lowered operation. The state of the fusion program is not well defined here, as we are still building the Kernel IR program. After the allocation lowering pass, a TensorView can have multiple defining expressions due to, e.g., initializations of buffers, and thus it is no longer SSA. What tv->definition() returns is the original expr, but we may be using it even when lowering an initialization.

That could cause a problem here: even though expr is an ldmatrix or stmatrix, the op actually being lowered may be something else, such as an initialization. To determine whether the actual op is indeed an ldmatrix or stmatrix, that information needs to be passed down from the indexing pass itself.

TensorView* tv,
const Expr* expr,
bool ld_st_matrix) {
// short-circuit: the op being lowered is not an ldmatrix or stmatrix
if (!ir_utils::isLdMatrixOp(expr) && !ir_utils::isStMatrixOp(expr)) {
if (!ld_st_matrix) {
return false;
}

NVF_ERROR(
ir_utils::isLdMatrixOp(expr) || ir_utils::isStMatrixOp(expr),
"Unexpected expr: ",
expr->toString());

// short-circuit: only the shared memory TensorView uses alternate loop
// domain. For ldmatrix, it is the input TensorView. For stmatrix, it is the
// output TensorView.
Expand Down Expand Up @@ -960,7 +979,8 @@ std::pair<std::vector<Val*>, std::vector<Val*>> TensorIndexer::
bool as_consumer,
const AllocationDomainInfo& alloc_info,
const std::vector<kir::ForLoop*>& for_loops,
const std::unordered_map<IterDomain*, Val*>& override_index) const {
const std::unordered_map<IterDomain*, Val*>& override_index,
bool ld_st_matrix) const {
std::vector<IterDomain*> indexed_ids;
indexed_ids.reserve(alloc_info.ids.size());
for (const auto& id : alloc_info.ids) {
Expand All @@ -969,7 +989,10 @@ std::pair<std::vector<Val*>, std::vector<Val*>> TensorIndexer::
}
}
auto index_info = computeIndex(
expr, indexed_ids, for_loops, isSharedMemoryTvForLdStMatrix(tv, expr));
expr,
indexed_ids,
for_loops,
shouldUseAlternateLoopDomain(tv, expr, ld_st_matrix));
for (const auto& [indexed_id, index] : override_index) {
index_info.index_map[traversalGraph().toGroup(indexed_id)] = index;
}
Expand Down
6 changes: 4 additions & 2 deletions csrc/id_model/indexing.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,8 @@ class TensorIndexer {
TensorView* tv,
const Expr* expr,
const std::vector<kir::ForLoop*>& loops,
const std::unordered_map<IterDomain*, Val*>& override_index = {}) const;
const std::unordered_map<IterDomain*, Val*>& override_index = {},
bool ld_st_matrix = false) const;

// Get the index of a loop domain.
Val* getLoopIndex(
Expand All @@ -93,7 +94,8 @@ class TensorIndexer {
bool as_consumer,
const AllocationDomainInfo& alloc_info,
const std::vector<kir::ForLoop*>& loops,
const std::unordered_map<IterDomain*, Val*>& override_index) const;
const std::unordered_map<IterDomain*, Val*>& override_index,
bool ld_st_matrix = false) const;

// Grab all for-loops whose indices are actually used in the given
// index vals. Note that IndexingInfo.loop_group_dependencies can be
Expand Down
10 changes: 6 additions & 4 deletions csrc/index_compute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2219,12 +2219,13 @@ kir::TensorIndex* Index::getProducerIndex(
const std::unordered_set<kir::ForLoop*>& rotated_loops,
const std::unordered_map<IterDomain*, Val*>& override_index,
bool generate_pointer,
DataType as_type) {
DataType as_type,
bool ld_st_matrix) {
Val* index = nullptr;

if (shouldUseTensorIndexer(producer, consumer, rotated_loops)) {
index = GpuLower::current()->tensorIndexer().getLinearIndex(
producer, consumer->definition(), loops, override_index);
producer, consumer->definition(), loops, override_index, ld_st_matrix);
if (generate_pointer) {
auto address_offset = index;
if (producer->getMemoryType() == MemoryType::Shared) {
Expand Down Expand Up @@ -2323,15 +2324,16 @@ kir::TensorIndex* Index::getConsumerIndex(
const std::unordered_set<kir::ForLoop*>& rotated_loops,
const std::unordered_map<IterDomain*, Val*>& override_index,
bool generate_pointer,
DataType as_type) {
DataType as_type,
bool ld_st_matrix) {
Val* index = nullptr;
if (!ir_utils::hasRootToLoopLinearTransformations(consumer) ||
ir_utils::isCpAsyncBulkLoad(consumer->definition()) ||
GpuLower::current()->idModelOptions().consumerIndex() ||
GpuLower::current()->tmemInfo().hasTMemTensor()) {
NVF_ERROR(rotated_loops.empty(), "Loop rotation is not supported");
index = GpuLower::current()->tensorIndexer().getLinearIndex(
consumer, consumer->definition(), loops, override_index);
consumer, consumer->definition(), loops, override_index, ld_st_matrix);
if (generate_pointer) {
auto address_offset = index;
if (consumer->getMemoryType() == MemoryType::Shared) {
Expand Down
6 changes: 4 additions & 2 deletions csrc/index_compute.h
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,8 @@ class Index {
const std::unordered_set<kir::ForLoop*>& rotated_loops,
const std::unordered_map<IterDomain*, Val*>& override_index = {},
bool generate_pointer = false,
DataType as_type = DataType::Null);
DataType as_type = DataType::Null,
bool ld_st_matrix = false);

// Consumer index dispatch
static kir::TensorIndex* getConsumerIndex(
Expand All @@ -505,7 +506,8 @@ class Index {
const std::unordered_set<kir::ForLoop*>& rotated_loops,
const std::unordered_map<IterDomain*, Val*>& override_index = {},
bool generate_pointer = false,
DataType as_type = DataType::Null);
DataType as_type = DataType::Null,
bool ld_st_matrix = false);

//! Returns a vector of strided indices mapped onto the
//! allocation domain of a producer tensor. The size of the returned
Expand Down