Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion src/tir/schedule/primitive/layout_transformation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1055,13 +1055,45 @@ class TransformationIntroducesPaddingError : public ScheduleError {
PrimExpr padding_predicate_;
};

// Make the dtypes of indices in IndexMap be the same as the dtype of the buffer shape, to avoid
// dtype-mismatch issues later.
IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Buffer& buf) {
const auto& initial_indices_orig = index_map->initial_indices;
ICHECK(buf->shape.size() == initial_indices_orig.size());

Array<Var> initial_indices;
Map<Var, PrimExpr> var_map;

for (size_t i = 0; i < buf->shape.size(); ++i) {
if (buf->shape[i]->dtype != initial_indices_orig[i].dtype()) {
auto new_idx = Var(initial_indices_orig[i]->name_hint, buf->shape[i]->dtype);
initial_indices.push_back(new_idx);
var_map.Set(initial_indices_orig[i], new_idx);
} else {
initial_indices.push_back(initial_indices_orig[i]);
}
}

if (!var_map.empty()) {
auto final_indices = index_map->final_indices.Map([&](PrimExpr index) {
return SubstituteWithDataTypeLegalization(index,
[&](const Var& var) { return var_map.Get(var); });
});
return IndexMap(initial_indices, final_indices);
}
return index_map;
}

void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index,
BufferIndexType buffer_index_type, const IndexMap& index_map,
BufferIndexType buffer_index_type, const IndexMap& index_map_orig,
const Optional<IndexMap>& pad_value) {
// Step 1: Input handling and error checking
const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref);
Buffer old_buffer =
GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index, buffer_index_type);

auto index_map = LegalizeIndexMapDType(index_map_orig, old_buffer);

auto [defining_site_sref, is_alloc] = GetBufferDefiningSite(block_sref, old_buffer);
if (defining_site_sref.defined() && !is_alloc) {
throw BufferIsSubregionError(self->mod, old_buffer);
Expand Down
21 changes: 21 additions & 0 deletions tests/python/unittest/test_tir_schedule_transform_layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -925,5 +925,26 @@ def expected(a: T.handle):
A[i, j] = T.if_then_else(i == 3 and 2 <= j, 0, 42, dtype="int32")


def test_index_map_dtype_legalize():
"""Test dtype legalization of the index map indices."""

@T.prim_func
def func(A: T.Buffer[T.int64(58), "int32"]):
for i in T.serial(T.int64(58)):
with T.block("block"):
vi = T.axis.remap("S", [i])
T.writes(A[vi])
A[vi] = 0

sch = tir.Schedule(func)

# # The following error is raised from the IterVar constructor without the dtype legalization.
# # TVMError: Check failed: dom->extent.dtype() == var.dtype() (int64 vs. int32) :
# # The dtype of the extent of an IterVar (int64) must match its associated Var's dtype (int32)
sch.transform_layout(
sch.get_block("block"), buffer="A", index_map=lambda h: [h // 8, h % 8], pad_value=0
)


if __name__ == "__main__":
tvm.testing.main()