diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc
index bbbbd2fdf56f..a9b367c4b7d9 100644
--- a/src/tir/schedule/primitive/layout_transformation.cc
+++ b/src/tir/schedule/primitive/layout_transformation.cc
@@ -1056,16 +1056,16 @@ class TransformationIntroducesPaddingError : public ScheduleError {
 
 // Make the dtypes of indices in IndexMap be the same as the dtype of the buffer shape, to avoid
 // dtype-mismatch issues later.
-IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Buffer& buf) {
+IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Array<PrimExpr>& args) {
   const auto& initial_indices_orig = index_map->initial_indices;
-  ICHECK(buf->shape.size() == initial_indices_orig.size());
+  ICHECK(args.size() == initial_indices_orig.size());
 
   Array<Var> initial_indices;
   Map<Var, PrimExpr> var_map;
 
-  for (size_t i = 0; i < buf->shape.size(); ++i) {
-    if (buf->shape[i]->dtype != initial_indices_orig[i].dtype()) {
-      auto new_idx = Var(initial_indices_orig[i]->name_hint, buf->shape[i]->dtype);
+  for (size_t i = 0; i < args.size(); ++i) {
+    if (args[i]->dtype != initial_indices_orig[i].dtype()) {
+      auto new_idx = Var(initial_indices_orig[i]->name_hint, args[i]->dtype);
       initial_indices.push_back(new_idx);
       var_map.Set(initial_indices_orig[i], new_idx);
     } else {
@@ -1078,7 +1078,12 @@ IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Buffer& buf) {
       return SubstituteWithDataTypeLegalization(index,
                                                 [&](const Var& var) { return var_map.Get(var); });
     });
-    return IndexMap(initial_indices, final_indices);
+    Optional<IndexMap> opt_inverse_index_map =
+        Downcast<Optional<IndexMap>>(index_map->inverse_index_map);
+    if (opt_inverse_index_map.defined()) {
+      opt_inverse_index_map = LegalizeIndexMapDType(opt_inverse_index_map.value(), final_indices);
+    }
+    return IndexMap(initial_indices, final_indices, opt_inverse_index_map);
   }
   return index_map;
 }
@@ -1091,7 +1096,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_
   Buffer old_buffer =
       GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index, buffer_index_type);
 
-  auto index_map = LegalizeIndexMapDType(index_map_orig, old_buffer);
+  auto index_map = LegalizeIndexMapDType(index_map_orig, old_buffer->shape);
 
   auto [defining_site_sref, is_alloc] = GetBufferDefiningSite(block_sref, old_buffer);
   if (defining_site_sref.defined() && !is_alloc) {
diff --git a/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py b/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py
index 98c1f7368580..80ca954cca5c 100644
--- a/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py
+++ b/tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py
@@ -480,5 +480,135 @@ def test_layout_rewrite_cache_read_multiple():
     tvm.ir.assert_structural_equal(sch.mod, Conv2dCacheReadMultipleRewritten)
 
 
+class TestLayoutRewriteInt64Index(BaseBeforeAfter):
+    def before(
+        p0: T.Buffer[(T.int64(12), T.int64(197), T.int64(64)), "int8"],
+        p1: T.Buffer[(T.int64(12), T.int64(197), T.int64(64)), "int8"],
+        T_batch_matmul_NT: T.Buffer[(T.int64(12), T.int64(197), T.int64(197)), "int32"],
+    ):
+        T.func_attr({"layout_free_buffers": [1], "global_symbol": "main", "tir.noalias": True})
+        for b_0_i_0_fused in T.parallel(T.int64(394)):
+            for j_0 in T.serial(T.int64(1)):
+                for b_1, i_1, j_1 in T.grid(T.int64(1), T.int64(1), T.int64(1)):
+                    for b_2_init, i_2_init, j_2_init, b_3_init, i_3_init, j_3_init in T.grid(
+                        T.int64(6), T.int64(1), T.int64(197), T.int64(1), T.int64(1), T.int64(1)
+                    ):
+ with T.block("T_batch_matmul_NT_init"): + v_b = T.axis.spatial( + T.int64(12), + b_3_init + + b_0_i_0_fused // T.int64(197) * T.int64(6) + + b_1 * T.int64(6) + + b_2_init, + ) + v_i = T.axis.spatial( + T.int64(197), + b_0_i_0_fused % T.int64(197) + i_1 + i_2_init + i_3_init, + ) + v_j = T.axis.spatial( + T.int64(197), + j_3_init + j_0 * T.int64(197) + j_1 * T.int64(197) + j_2_init, + ) + T_batch_matmul_NT[v_b, v_i, v_j] = 0 + for k_0, b_2, i_2, j_2, k_1, b_3, i_3, j_3 in T.grid( + T.int64(64), + T.int64(6), + T.int64(1), + T.int64(197), + T.int64(1), + T.int64(1), + T.int64(1), + T.int64(1), + ): + with T.block("T_batch_matmul_NT_update"): + v_b = T.axis.spatial( + T.int64(12), + b_3 + + b_0_i_0_fused // T.int64(197) * T.int64(6) + + b_1 * T.int64(6) + + b_2, + ) + v_i = T.axis.spatial( + T.int64(197), b_0_i_0_fused % T.int64(197) + i_1 + i_2 + i_3 + ) + v_j = T.axis.spatial( + T.int64(197), j_3 + j_0 * T.int64(197) + j_1 * T.int64(197) + j_2 + ) + v_k = T.axis.reduce(T.int64(64), k_0 + k_1) + T_batch_matmul_NT[v_b, v_i, v_j] = T_batch_matmul_NT[ + v_b, v_i, v_j + ] + T.Cast("int32", p0[v_b, v_i, v_k]) * T.Cast( + "int32", p1[v_b, v_j, v_k] + ) + + def expected( + p0: T.Buffer[(T.int64(12), T.int64(197), T.int64(64)), "int8"], + p1: T.Buffer[(T.int64(12), T.int64(197), T.int64(64)), "int8"], + T_batch_matmul_NT: T.Buffer[(T.int64(12), T.int64(197), T.int64(197)), "int32"], + ): + T.func_attr({"tir.noalias": True, "global_symbol": "main", "layout_free_buffers": [1]}) + p1_global = T.alloc_buffer( + [T.int64(2), T.int64(64), T.int64(6), T.int64(197)], dtype="int8" + ) + for ax0, ax1, ax2 in T.grid(T.int64(12), T.int64(197), T.int64(64)): + with T.block("p1_global"): + v0, v1, v2 = T.axis.remap("SSS", [ax0, ax1, ax2]) + T.reads(p1[v0, v1, v2]) + T.writes(p1_global[v0 // T.int64(6), v2, v0 % T.int64(6), v1]) + T.block_attr({"meta_schedule.layout_rewrite_preproc": True}) + p1_global[v0 // T.int64(6), v2, v0 % T.int64(6), v1] = p1[v0, v1, v2] + for b_0_i_0_fused in T.parallel(T.int64(394)): + for j_0, b_1, i_1, j_1 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1)): + for b_2_init, i_2_init, j_2_init, b_3_init, i_3_init, j_3_init in T.grid( + T.int64(6), T.int64(1), T.int64(197), T.int64(1), T.int64(1), T.int64(1) + ): + with T.block("T_batch_matmul_NT_init"): + v_b = T.axis.spatial( + T.int64(12), + b_3_init + + b_0_i_0_fused // T.int64(197) * T.int64(6) + + b_1 * T.int64(6) + + b_2_init, + ) + v_i = T.axis.spatial( + T.int64(197), b_0_i_0_fused % T.int64(197) + i_1 + i_2_init + i_3_init + ) + v_j = T.axis.spatial( + T.int64(197), + j_3_init + j_0 * T.int64(197) + j_1 * T.int64(197) + j_2_init, + ) + T_batch_matmul_NT[v_b, v_i, v_j] = 0 + for k_0, b_2, i_2, j_2, k_1, b_3, i_3, j_3 in T.grid( + T.int64(64), + T.int64(6), + T.int64(1), + T.int64(197), + T.int64(1), + T.int64(1), + T.int64(1), + T.int64(1), + ): + with T.block("T_batch_matmul_NT_update"): + v_b = T.axis.spatial( + T.int64(12), + b_3 + + b_0_i_0_fused // T.int64(197) * T.int64(6) + + b_1 * T.int64(6) + + b_2, + ) + v_i = T.axis.spatial( + T.int64(197), b_0_i_0_fused % T.int64(197) + i_1 + i_2 + i_3 + ) + v_j = T.axis.spatial( + T.int64(197), j_3 + j_0 * T.int64(197) + j_1 * T.int64(197) + j_2 + ) + v_k = T.axis.reduce(T.int64(64), k_0 + k_1) + T_batch_matmul_NT[v_b, v_i, v_j] = T_batch_matmul_NT[ + v_b, v_i, v_j + ] + T.Cast("int32", p0[v_b, v_i, v_k]) * T.Cast( + "int32", p1_global[v_b // T.int64(6), v_k, v_b % T.int64(6), v_j] + ) + + if __name__ == "__main__": tvm.testing.main()