diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index c0b4ddfb4ac3..bf618af8de54 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -1055,13 +1055,45 @@ class TransformationIntroducesPaddingError : public ScheduleError { PrimExpr padding_predicate_; }; +// Make the dtypes of indices in IndexMap be the same as the dtype of the buffer shape, to avoid +// dtype-mismatch issues later. +IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Buffer& buf) { + const auto& initial_indices_orig = index_map->initial_indices; + ICHECK(buf->shape.size() == initial_indices_orig.size()); + + Array initial_indices; + Map var_map; + + for (size_t i = 0; i < buf->shape.size(); ++i) { + if (buf->shape[i]->dtype != initial_indices_orig[i].dtype()) { + auto new_idx = Var(initial_indices_orig[i]->name_hint, buf->shape[i]->dtype); + initial_indices.push_back(new_idx); + var_map.Set(initial_indices_orig[i], new_idx); + } else { + initial_indices.push_back(initial_indices_orig[i]); + } + } + + if (!var_map.empty()) { + auto final_indices = index_map->final_indices.Map([&](PrimExpr index) { + return SubstituteWithDataTypeLegalization(index, + [&](const Var& var) { return var_map.Get(var); }); + }); + return IndexMap(initial_indices, final_indices); + } + return index_map; +} + void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index, - BufferIndexType buffer_index_type, const IndexMap& index_map, + BufferIndexType buffer_index_type, const IndexMap& index_map_orig, const Optional& pad_value) { // Step 1: Input handling and error checking const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); Buffer old_buffer = GetNthAccessBuffer(self, GetRef(block_ptr), buffer_index, buffer_index_type); + + auto index_map = LegalizeIndexMapDType(index_map_orig, old_buffer); + auto [defining_site_sref, is_alloc] = GetBufferDefiningSite(block_sref, old_buffer); if (defining_site_sref.defined() && !is_alloc) { throw BufferIsSubregionError(self->mod, old_buffer); diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index e90478922324..35b2dd53b80e 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -925,5 +925,26 @@ def expected(a: T.handle): A[i, j] = T.if_then_else(i == 3 and 2 <= j, 0, 42, dtype="int32") +def test_index_map_dtype_legalize(): + """Test dtype legalization of the index map indices.""" + + @T.prim_func + def func(A: T.Buffer[T.int64(58), "int32"]): + for i in T.serial(T.int64(58)): + with T.block("block"): + vi = T.axis.remap("S", [i]) + T.writes(A[vi]) + A[vi] = 0 + + sch = tir.Schedule(func) + + # # The following error is raised from the IterVar constructor without the dtype legalization. + # # TVMError: Check failed: dom->extent.dtype() == var.dtype() (int64 vs. int32) : + # # The dtype of the extent of an IterVar (int64) must match its associated Var's dtype (int32) + sch.transform_layout( + sch.get_block("block"), buffer="A", index_map=lambda h: [h // 8, h % 8], pad_value=0 + ) + + if __name__ == "__main__": tvm.testing.main()