src/tir/schedule/primitive/layout_transformation.cc (12 additions, 7 deletions)

@@ -1056,16 +1056,16 @@ class TransformationIntroducesPaddingError : public ScheduleError {
 
 // Make the dtypes of indices in IndexMap be the same as the dtype of the buffer shape, to avoid
 // dtype-mismatch issues later.
-IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Buffer& buf) {
+IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Array<PrimExpr>& args) {
   const auto& initial_indices_orig = index_map->initial_indices;
-  ICHECK(buf->shape.size() == initial_indices_orig.size());
+  ICHECK(args.size() == initial_indices_orig.size());
 
   Array<Var> initial_indices;
   Map<Var, PrimExpr> var_map;
 
-  for (size_t i = 0; i < buf->shape.size(); ++i) {
-    if (buf->shape[i]->dtype != initial_indices_orig[i].dtype()) {
-      auto new_idx = Var(initial_indices_orig[i]->name_hint, buf->shape[i]->dtype);
+  for (size_t i = 0; i < args.size(); ++i) {
+    if (args[i]->dtype != initial_indices_orig[i].dtype()) {
+      auto new_idx = Var(initial_indices_orig[i]->name_hint, args[i]->dtype);
       initial_indices.push_back(new_idx);
       var_map.Set(initial_indices_orig[i], new_idx);
     } else {
@@ -1078,7 +1078,12 @@ IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Buffer& buf) {
       return SubstituteWithDataTypeLegalization(index,
                                                 [&](const Var& var) { return var_map.Get(var); });
     });
-    return IndexMap(initial_indices, final_indices);
+    Optional<IndexMap> opt_inverse_index_map =
+        Downcast<Optional<IndexMap>>(index_map->inverse_index_map);
+    if (opt_inverse_index_map.defined()) {
+      opt_inverse_index_map = LegalizeIndexMapDType(opt_inverse_index_map.value(), final_indices);
+    }
+    return IndexMap(initial_indices, final_indices, opt_inverse_index_map);
   }
   return index_map;
 }
@@ -1091,7 +1096,7 @@ void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_
   Buffer old_buffer =
       GetNthAccessBuffer(self, GetRef<Block>(block_ptr), buffer_index, buffer_index_type);
 
-  auto index_map = LegalizeIndexMapDType(index_map_orig, old_buffer);
+  auto index_map = LegalizeIndexMapDType(index_map_orig, old_buffer->shape);
 
   auto [defining_site_sref, is_alloc] = GetBufferDefiningSite(block_sref, old_buffer);
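For context (an illustrative sketch, not part of the diff): the scenario this fix targets is a layout transform applied to a buffer whose extents are int64, while the variables of an index map built from a Python lambda carry the default int32 dtype. LegalizeIndexMapDType substitutes the map's indices with variables of the shape's dtype, and with this change it also recurses into an attached inverse index map, legalizing it against the forward map's output indices. In the sketch below, the function name func, the "copy" block, and the ("write", 0) selector are made up for the example; the APIs are mainline tvm.tir.Schedule ones.

    import tvm
    from tvm.script import tir as T


    # Buffer extents are int64, mirroring the new test case below.
    @T.prim_func
    def func(A: T.Buffer[(T.int64(12), T.int64(197), T.int64(64)), "int8"]):
        for b, i, k in T.grid(T.int64(12), T.int64(197), T.int64(64)):
            with T.block("copy"):
                vb, vi, vk = T.axis.remap("SSS", [b, i, k])
                A[vb, vi, vk] = A[vb, vi, vk]


    sch = tvm.tir.Schedule(func)
    block = sch.get_block("copy")
    # The lambda yields index variables with the default int32 dtype;
    # TransformLayout now legalizes them against the buffer's int64 shape
    # before rewriting accesses, so the mixed dtypes no longer trigger
    # dtype-mismatch errors downstream.
    sch.transform_layout(block, ("write", 0), lambda b, i, k: (b // 6, k, b % 6, i))
    print(sch.mod)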
tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py (130 additions, 0 deletions; the entire test class below is newly added)

@@ -480,5 +480,135 @@ def test_layout_rewrite_cache_read_multiple():
tvm.ir.assert_structural_equal(sch.mod, Conv2dCacheReadMultipleRewritten)


class TestLayoutRewriteInt64Index(BaseBeforeAfter):
def before(
p0: T.Buffer[(T.int64(12), T.int64(197), T.int64(64)), "int8"],
p1: T.Buffer[(T.int64(12), T.int64(197), T.int64(64)), "int8"],
T_batch_matmul_NT: T.Buffer[(T.int64(12), T.int64(197), T.int64(197)), "int32"],
):
T.func_attr({"layout_free_buffers": [1], "global_symbol": "main", "tir.noalias": True})
for b_0_i_0_fused in T.parallel(T.int64(394)):
for j_0 in T.serial(T.int64(1)):
for b_1, i_1, j_1 in T.grid(T.int64(1), T.int64(1), T.int64(1)):
for b_2_init, i_2_init, j_2_init, b_3_init, i_3_init, j_3_init in T.grid(
T.int64(6), T.int64(1), T.int64(197), T.int64(1), T.int64(1), T.int64(1)
):
with T.block("T_batch_matmul_NT_init"):
v_b = T.axis.spatial(
T.int64(12),
b_3_init
+ b_0_i_0_fused // T.int64(197) * T.int64(6)
+ b_1 * T.int64(6)
+ b_2_init,
)
v_i = T.axis.spatial(
T.int64(197),
b_0_i_0_fused % T.int64(197) + i_1 + i_2_init + i_3_init,
)
v_j = T.axis.spatial(
T.int64(197),
j_3_init + j_0 * T.int64(197) + j_1 * T.int64(197) + j_2_init,
)
T_batch_matmul_NT[v_b, v_i, v_j] = 0
for k_0, b_2, i_2, j_2, k_1, b_3, i_3, j_3 in T.grid(
T.int64(64),
T.int64(6),
T.int64(1),
T.int64(197),
T.int64(1),
T.int64(1),
T.int64(1),
T.int64(1),
):
with T.block("T_batch_matmul_NT_update"):
v_b = T.axis.spatial(
T.int64(12),
b_3
+ b_0_i_0_fused // T.int64(197) * T.int64(6)
+ b_1 * T.int64(6)
+ b_2,
)
v_i = T.axis.spatial(
T.int64(197), b_0_i_0_fused % T.int64(197) + i_1 + i_2 + i_3
)
v_j = T.axis.spatial(
T.int64(197), j_3 + j_0 * T.int64(197) + j_1 * T.int64(197) + j_2
)
v_k = T.axis.reduce(T.int64(64), k_0 + k_1)
T_batch_matmul_NT[v_b, v_i, v_j] = T_batch_matmul_NT[
v_b, v_i, v_j
] + T.Cast("int32", p0[v_b, v_i, v_k]) * T.Cast(
"int32", p1[v_b, v_j, v_k]
)

def expected(
p0: T.Buffer[(T.int64(12), T.int64(197), T.int64(64)), "int8"],
p1: T.Buffer[(T.int64(12), T.int64(197), T.int64(64)), "int8"],
T_batch_matmul_NT: T.Buffer[(T.int64(12), T.int64(197), T.int64(197)), "int32"],
):
T.func_attr({"tir.noalias": True, "global_symbol": "main", "layout_free_buffers": [1]})
p1_global = T.alloc_buffer(
[T.int64(2), T.int64(64), T.int64(6), T.int64(197)], dtype="int8"
)
for ax0, ax1, ax2 in T.grid(T.int64(12), T.int64(197), T.int64(64)):
with T.block("p1_global"):
v0, v1, v2 = T.axis.remap("SSS", [ax0, ax1, ax2])
T.reads(p1[v0, v1, v2])
T.writes(p1_global[v0 // T.int64(6), v2, v0 % T.int64(6), v1])
T.block_attr({"meta_schedule.layout_rewrite_preproc": True})
p1_global[v0 // T.int64(6), v2, v0 % T.int64(6), v1] = p1[v0, v1, v2]
for b_0_i_0_fused in T.parallel(T.int64(394)):
for j_0, b_1, i_1, j_1 in T.grid(T.int64(1), T.int64(1), T.int64(1), T.int64(1)):
for b_2_init, i_2_init, j_2_init, b_3_init, i_3_init, j_3_init in T.grid(
T.int64(6), T.int64(1), T.int64(197), T.int64(1), T.int64(1), T.int64(1)
):
with T.block("T_batch_matmul_NT_init"):
v_b = T.axis.spatial(
T.int64(12),
b_3_init
+ b_0_i_0_fused // T.int64(197) * T.int64(6)
+ b_1 * T.int64(6)
+ b_2_init,
)
v_i = T.axis.spatial(
T.int64(197), b_0_i_0_fused % T.int64(197) + i_1 + i_2_init + i_3_init
)
v_j = T.axis.spatial(
T.int64(197),
j_3_init + j_0 * T.int64(197) + j_1 * T.int64(197) + j_2_init,
)
T_batch_matmul_NT[v_b, v_i, v_j] = 0
for k_0, b_2, i_2, j_2, k_1, b_3, i_3, j_3 in T.grid(
T.int64(64),
T.int64(6),
T.int64(1),
T.int64(197),
T.int64(1),
T.int64(1),
T.int64(1),
T.int64(1),
):
with T.block("T_batch_matmul_NT_update"):
v_b = T.axis.spatial(
T.int64(12),
b_3
+ b_0_i_0_fused // T.int64(197) * T.int64(6)
+ b_1 * T.int64(6)
+ b_2,
)
v_i = T.axis.spatial(
T.int64(197), b_0_i_0_fused % T.int64(197) + i_1 + i_2 + i_3
)
v_j = T.axis.spatial(
T.int64(197), j_3 + j_0 * T.int64(197) + j_1 * T.int64(197) + j_2
)
v_k = T.axis.reduce(T.int64(64), k_0 + k_1)
T_batch_matmul_NT[v_b, v_i, v_j] = T_batch_matmul_NT[
v_b, v_i, v_j
] + T.Cast("int32", p0[v_b, v_i, v_k]) * T.Cast(
"int32", p1_global[v_b // T.int64(6), v_k, v_b % T.int64(6), v_j]
)


if __name__ == "__main__":
tvm.testing.main()
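The new test is picked up by the file's existing tvm.testing.main() entry point; to run just this class with pytest (a standard invocation, assuming a built TVM checkout):

    pytest tests/python/unittest/test_meta_schedule_postproc_rewrite_layout.py -k TestLayoutRewriteInt64Index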