From 697296ab4a928dd8a1397961f15a59ab82225bd0 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 22 Nov 2022 17:33:03 +0900 Subject: [PATCH 1/5] [TIR] Fix buffer shape and IndexMap indices dtype mismatch --- .../primitive/layout_transformation.cc | 30 ++++++++++++- .../test_tir_schedule_transform_layout.py | 45 +++++++++++++++++++ 2 files changed, 74 insertions(+), 1 deletion(-) diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index c0b4ddfb4ac3..3cef74bdca77 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -1055,13 +1055,41 @@ class TransformationIntroducesPaddingError : public ScheduleError { PrimExpr padding_predicate_; }; +IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Buffer& buf) { + auto initial_indices_orig = index_map->initial_indices; + ICHECK(buf->shape.size() == initial_indices_orig.size()); + + Array initial_indices; + Map var_map; + + for (size_t i = 0; i < buf->shape.size(); ++i) { + if (buf->shape[i]->dtype != initial_indices_orig[i].dtype()) { + auto new_idx = Var(initial_indices_orig[i]->name_hint, buf->shape[i]->dtype); + initial_indices.push_back(new_idx); + var_map.Set(initial_indices_orig[i], new_idx); + } + } + + if (!var_map.empty()) { + auto final_indices = index_map->final_indices.Map([&](PrimExpr index) { + return SubstituteWithDataTypeLegalization(index, + [&](const Var& var) { return var_map.Get(var); }); + }); + return IndexMap(initial_indices, final_indices); + } + return index_map; +} + void TransformLayout(ScheduleState self, const StmtSRef& block_sref, int buffer_index, - BufferIndexType buffer_index_type, const IndexMap& index_map, + BufferIndexType buffer_index_type, const IndexMap& index_map_orig, const Optional& pad_value) { // Step 1: Input handling and error checking const BlockNode* block_ptr = TVM_SREF_TO_BLOCK(block_sref); Buffer old_buffer = GetNthAccessBuffer(self, GetRef(block_ptr), buffer_index, buffer_index_type); + + auto index_map = LegalizeIndexMapDType(index_map_orig, old_buffer); + auto [defining_site_sref, is_alloc] = GetBufferDefiningSite(block_sref, old_buffer); if (defining_site_sref.defined() && !is_alloc) { throw BufferIsSubregionError(self->mod, old_buffer); diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index e90478922324..71d6109c3c57 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -173,6 +173,35 @@ def two_elementwise_unit_dim(A: T.Buffer[(1, 128), "float32"], C: T.Buffer[(1, 1 vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 + + +@tvm.script.ir_module +class Conv2dNCHW32c: + @T.prim_func + def main(p0: T.Buffer[(T.int64(1), T.int64(2), T.int64(56), T.int64(56), T.int64(32)), "uint8"], p1: T.Buffer[(T.int64(2), T.int64(2), T.int64(3), +T.int64(3), T.int64(8), T.int64(32), T.int64(4)), "uint8"], conv2d_NCHWc_int8: T.Buffer[(T.int64(1), T.int64(2), T.int64(56), T.int64(56), T.int64(32)), "int32"]): + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + data_pad = T.alloc_buffer([T.int64(1), T.int64(2), T.int64(58), T.int64(58), T.int64(32)], dtype="uint8") + for i0, i1, i2, i3, i4 in T.grid(T.int64(1), T.int64(2), T.int64(58), T.int64(58), T.int64(32)): + with T.block("data_pad"): + i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(p0[i0_1, i1_1, i2_1 - T.int64(1), i3_1 - T.int64(1), i4_1]) + T.writes(data_pad[i0_1, i1_1, i2_1, i3_1, i4_1]) + data_pad[i0_1, i1_1, i2_1, i3_1, i4_1] = T.if_then_else(T.int64(1) <= i2_1 and i2_1 < T.int64(57) and T.int64(1) <= i3_1 and i3_1 < T.int64(57), p0[i0_1, i1_1, i2_1 - T.int64(1), i3_1 - T.int64(1), i4_1], T.uint8(0), dtype="uint8") + for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(T.int64(1), T.int64(2), T.int64(56), T.int64(56), T.int64(32), T.int64(3), T.int64(3), T.int64(2), T.int64(8), T.int64(4)): + with T.block("conv2d_NCHWc_int8"): + n, oc_chunk, oh, ow, oc_block, kh, kw, ic_outer, ic_f_inner, ic_s_inner = T.axis.remap("SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9]) + T.reads(data_pad[n, ic_outer, oh + kh, ow + kw, ic_f_inner * T.int64(4) + ic_s_inner], p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner]) + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block]) + T.block_attr({"schedule_rule":"conv2d_NCHWc_int8"}) + with T.init(): + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0 + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] + T.Cast("int32", data_pad[n, ic_outer, oh + kh, ow + kw, ic_f_inner * T.int64(4) + ic_s_inner]) * T.Cast("int32", p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner]) + + # pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks # fmt: on @@ -925,5 +954,21 @@ def expected(a: T.handle): A[i, j] = T.if_then_else(i == 3 and 2 <= j, 0, 42, dtype="int32") +def test_index_map_dtype_legalize(): + """Test dtype legalization of the index map indices.""" + + def index_map_nchw32c_nchw8h8w32c(n_batch, channel, height, width, channel_32): + return [n_batch, channel, height // 8, width // 8, height % 8, width % 8, channel_32] + + sch = tir.Schedule(Conv2dNCHW32c, debug_mask="all") + + conv2d_block = sch.get_block("conv2d_NCHWc_int8") + sch.cache_read(conv2d_block, 0, "global.vtcm") + + sch.transform_layout( + conv2d_block, ("read", 0), index_map=index_map_nchw32c_nchw8h8w32c, pad_value=0 + ) + + if __name__ == "__main__": tvm.testing.main() From 9b9bac93ee3f73048f1226c706f811eda9dc5dec Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 22 Nov 2022 17:52:34 +0900 Subject: [PATCH 2/5] turn off debug_mask to suppress flaky VerifySRefTree error --- tests/python/unittest/test_tir_schedule_transform_layout.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index 71d6109c3c57..9b89787cd3a7 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -960,7 +960,7 @@ def test_index_map_dtype_legalize(): def index_map_nchw32c_nchw8h8w32c(n_batch, channel, height, width, channel_32): return [n_batch, channel, height // 8, width // 8, height % 8, width % 8, channel_32] - sch = tir.Schedule(Conv2dNCHW32c, debug_mask="all") + sch = tir.Schedule(Conv2dNCHW32c) conv2d_block = sch.get_block("conv2d_NCHWc_int8") sch.cache_read(conv2d_block, 0, "global.vtcm") From b37bbb96fb44870668333ba0c8db97a22ba8516d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 22 Nov 2022 18:38:36 +0900 Subject: [PATCH 3/5] add comment --- src/tir/schedule/primitive/layout_transformation.cc | 2 ++ tests/python/unittest/test_tir_schedule_transform_layout.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 3cef74bdca77..776dd685238d 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -1055,6 +1055,8 @@ class TransformationIntroducesPaddingError : public ScheduleError { PrimExpr padding_predicate_; }; +// Make the dtypes of indices in IndexMap be the same as the dtype of the buffer shape, to avoid +// dtype-mismatch issues later. IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Buffer& buf) { auto initial_indices_orig = index_map->initial_indices; ICHECK(buf->shape.size() == initial_indices_orig.size()); diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index 9b89787cd3a7..2f20922721a7 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -965,6 +965,9 @@ def index_map_nchw32c_nchw8h8w32c(n_batch, channel, height, width, channel_32): conv2d_block = sch.get_block("conv2d_NCHWc_int8") sch.cache_read(conv2d_block, 0, "global.vtcm") + # The following error is raised from the IterVar constructor without the dtype legalization. + # TVMError: Check failed: dom->extent.dtype() == var.dtype() (int64 vs. int32) : + # The dtype of the extent of an IterVar (int64) must match its associated Var's dtype (int32) sch.transform_layout( conv2d_block, ("read", 0), index_map=index_map_nchw32c_nchw8h8w32c, pad_value=0 ) From 21d550710af89446cce3dbf3fff63466cdd5cbea Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 23 Nov 2022 06:01:03 +0900 Subject: [PATCH 4/5] add missing const auto&, handle cases dtypes partially match --- src/tir/schedule/primitive/layout_transformation.cc | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/tir/schedule/primitive/layout_transformation.cc b/src/tir/schedule/primitive/layout_transformation.cc index 776dd685238d..bf618af8de54 100644 --- a/src/tir/schedule/primitive/layout_transformation.cc +++ b/src/tir/schedule/primitive/layout_transformation.cc @@ -1058,7 +1058,7 @@ class TransformationIntroducesPaddingError : public ScheduleError { // Make the dtypes of indices in IndexMap be the same as the dtype of the buffer shape, to avoid // dtype-mismatch issues later. IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Buffer& buf) { - auto initial_indices_orig = index_map->initial_indices; + const auto& initial_indices_orig = index_map->initial_indices; ICHECK(buf->shape.size() == initial_indices_orig.size()); Array initial_indices; @@ -1069,6 +1069,8 @@ IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Buffer& buf) { auto new_idx = Var(initial_indices_orig[i]->name_hint, buf->shape[i]->dtype); initial_indices.push_back(new_idx); var_map.Set(initial_indices_orig[i], new_idx); + } else { + initial_indices.push_back(initial_indices_orig[i]); } } From aa7f08ef973b1ce2924d402fa38ce7afb01c7387 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 23 Nov 2022 06:02:17 +0900 Subject: [PATCH 5/5] massively simplify test case --- .../test_tir_schedule_transform_layout.py | 51 +++++-------------- 1 file changed, 12 insertions(+), 39 deletions(-) diff --git a/tests/python/unittest/test_tir_schedule_transform_layout.py b/tests/python/unittest/test_tir_schedule_transform_layout.py index 2f20922721a7..35b2dd53b80e 100644 --- a/tests/python/unittest/test_tir_schedule_transform_layout.py +++ b/tests/python/unittest/test_tir_schedule_transform_layout.py @@ -173,35 +173,6 @@ def two_elementwise_unit_dim(A: T.Buffer[(1, 128), "float32"], C: T.Buffer[(1, 1 vi, vj = T.axis.remap("SS", [i, j]) C[vi, vj] = B[vi, vj] + 1.0 - - -@tvm.script.ir_module -class Conv2dNCHW32c: - @T.prim_func - def main(p0: T.Buffer[(T.int64(1), T.int64(2), T.int64(56), T.int64(56), T.int64(32)), "uint8"], p1: T.Buffer[(T.int64(2), T.int64(2), T.int64(3), -T.int64(3), T.int64(8), T.int64(32), T.int64(4)), "uint8"], conv2d_NCHWc_int8: T.Buffer[(T.int64(1), T.int64(2), T.int64(56), T.int64(56), T.int64(32)), "int32"]): - # function attr dict - T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body - # with T.block("root") - data_pad = T.alloc_buffer([T.int64(1), T.int64(2), T.int64(58), T.int64(58), T.int64(32)], dtype="uint8") - for i0, i1, i2, i3, i4 in T.grid(T.int64(1), T.int64(2), T.int64(58), T.int64(58), T.int64(32)): - with T.block("data_pad"): - i0_1, i1_1, i2_1, i3_1, i4_1 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) - T.reads(p0[i0_1, i1_1, i2_1 - T.int64(1), i3_1 - T.int64(1), i4_1]) - T.writes(data_pad[i0_1, i1_1, i2_1, i3_1, i4_1]) - data_pad[i0_1, i1_1, i2_1, i3_1, i4_1] = T.if_then_else(T.int64(1) <= i2_1 and i2_1 < T.int64(57) and T.int64(1) <= i3_1 and i3_1 < T.int64(57), p0[i0_1, i1_1, i2_1 - T.int64(1), i3_1 - T.int64(1), i4_1], T.uint8(0), dtype="uint8") - for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(T.int64(1), T.int64(2), T.int64(56), T.int64(56), T.int64(32), T.int64(3), T.int64(3), T.int64(2), T.int64(8), T.int64(4)): - with T.block("conv2d_NCHWc_int8"): - n, oc_chunk, oh, ow, oc_block, kh, kw, ic_outer, ic_f_inner, ic_s_inner = T.axis.remap("SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9]) - T.reads(data_pad[n, ic_outer, oh + kh, ow + kw, ic_f_inner * T.int64(4) + ic_s_inner], p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner]) - T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block]) - T.block_attr({"schedule_rule":"conv2d_NCHWc_int8"}) - with T.init(): - conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0 - conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] + T.Cast("int32", data_pad[n, ic_outer, oh + kh, ow + kw, ic_f_inner * T.int64(4) + ic_s_inner]) * T.Cast("int32", p1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner]) - - # pylint: enable=no-member,invalid-name,unused-variable,line-too-long,redefined-outer-name,unexpected-keyword-arg,too-many-nested-blocks # fmt: on @@ -957,19 +928,21 @@ def expected(a: T.handle): def test_index_map_dtype_legalize(): """Test dtype legalization of the index map indices.""" - def index_map_nchw32c_nchw8h8w32c(n_batch, channel, height, width, channel_32): - return [n_batch, channel, height // 8, width // 8, height % 8, width % 8, channel_32] - - sch = tir.Schedule(Conv2dNCHW32c) + @T.prim_func + def func(A: T.Buffer[T.int64(58), "int32"]): + for i in T.serial(T.int64(58)): + with T.block("block"): + vi = T.axis.remap("S", [i]) + T.writes(A[vi]) + A[vi] = 0 - conv2d_block = sch.get_block("conv2d_NCHWc_int8") - sch.cache_read(conv2d_block, 0, "global.vtcm") + sch = tir.Schedule(func) - # The following error is raised from the IterVar constructor without the dtype legalization. - # TVMError: Check failed: dom->extent.dtype() == var.dtype() (int64 vs. int32) : - # The dtype of the extent of an IterVar (int64) must match its associated Var's dtype (int32) + # # The following error is raised from the IterVar constructor without the dtype legalization. + # # TVMError: Check failed: dom->extent.dtype() == var.dtype() (int64 vs. int32) : + # # The dtype of the extent of an IterVar (int64) must match its associated Var's dtype (int32) sch.transform_layout( - conv2d_block, ("read", 0), index_map=index_map_nchw32c_nchw8h8w32c, pad_value=0 + sch.get_block("block"), buffer="A", index_map=lambda h: [h // 8, h % 8], pad_value=0 )