From 062d8c2ebb468fcc97fdfccd98470c40fc923565 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 18 Apr 2022 07:22:02 +0900 Subject: [PATCH 01/19] [TIR] Add TensorizeInfo and GetTensorizeLoopMapping --- src/tir/schedule/analysis.h | 26 +++++++ src/tir/schedule/analysis/analysis.cc | 102 ++++++++++++++++++++++++++ 2 files changed, 128 insertions(+) diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index b76d41326ff1..cee45e1a4398 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -656,6 +656,32 @@ Array AnalyzeRegionLowerBound(const BufferRegion& region, const P const StmtSRef& dom_high_exclusive, arith::Analyzer* analyzer); +/*! \brief Necessary information used for tensorization */ +class TensorizeInfoNode : public Object { + public: + /*! \brief Maps block loops to desc loops */ + Map loop_map; + /*! \brief Maps loops in desc to its index, outer to inner */ + Map desc_loop_indexer; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("loop_map", &loop_map); + v->Visit("desc_loop_indexer", &desc_loop_indexer); + } + + static constexpr const char* _type_key = "tir.analysis.TensorizeInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeInfoNode, Object); +}; + +class TensorizeInfo : public ObjectRef { + public: + TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorizeInfo, ObjectRef, TensorizeInfoNode); +}; + +Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, + const tir::StmtSRef& block_sref, + const tir::PrimFunc& desc_func); + } // namespace tir } // namespace tvm diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 4a7ac401dd60..8245e7b11cb1 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -2028,5 +2028,107 @@ bool NeedsRFactorOrCrossThreadReduction(const tir::ScheduleState& self, // } } +TVM_REGISTER_NODE_TYPE(TensorizeInfoNode); + +Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, + const tir::StmtSRef& block_sref, + const tir::PrimFunc& desc_func) { + // Try to do tiling automatically if possible + // Now the heuristic is that if block's block var binding is constant + loop var, + // in other words, with tir.block(..., vi=Ci+i, vj=Cj+j, vk=Ck+k), then we split and reorder + // i, j, k according to the loops outside desc_block + // Collect the loops outside block + arith::Analyzer analyzer; + const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref); + // Step 1. Analyze desc_func, extract its block, loops and loop vars + const tir::BlockRealizeNode* desc_block = nullptr; + std::vector desc_loops; + std::unordered_set desc_loop_vars; + const auto* desc_scope_realize = desc_func->body.as(); + ICHECK(desc_scope_realize); + { + auto f_visit = [&desc_block, &desc_loops, &desc_loop_vars, + &analyzer](const ObjectRef& obj) -> bool { + // Extract the block + if (const auto* block = obj.as()) { + desc_block = block; + return false; + } + // Extract the loops + if (const auto* loop = obj.as()) { + desc_loops.push_back(loop); + desc_loop_vars.insert(loop->loop_var.get()); + if (!analyzer.CanProve(loop->min == 0)) { + return false; + } + } + return true; + }; + tir::PostOrderVisit(desc_scope_realize->block->body, f_visit); + std::reverse(desc_loops.begin(), desc_loops.end()); + ICHECK(desc_block); + } + // Step 2. 
Check if `desc_block` matches `block` + // Ignore the scope of buffers when comparing, since we can do cache_read/write + const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false); + const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref); + std::vector block_loops; + std::unordered_set block_loop_vars; + { + for (const tir::StmtSRefNode* loop_sref = block_sref->parent;; loop_sref = loop_sref->parent) { + const auto* loop = loop_sref->StmtAs(); + if (loop == nullptr || loop->body->IsInstance()) { + break; + } + block_loops.push_back(loop); + block_loop_vars.insert(loop->loop_var.get()); + if (!analyzer.CanProve(loop->min == 0)) { + return NullOpt; + } + } + std::reverse(block_loops.begin(), block_loops.end()); + } + // Step 4. Map from block loops to desc block loops + ObjectPtr ret = make_object(); + int n_block_vars = block->iter_values.size(); + int n_desc_vars = desc_block->iter_values.size(); + int offset = n_block_vars - n_desc_vars; + if (offset < 0) { + return NullOpt; + } + + std::vector iter_types_desc; + for (const IterVar& iter_var : desc_block->block->iter_vars) { + iter_types_desc.push_back(iter_var->iter_type); + } + + std::vector iter_types_block = GetBlockVarTypes(block_sref); + ICHECK(block_loops.size() == iter_types_block.size()); + + int next_block_ind = block_loops.size() - 1; + for (int i_desc = n_desc_vars - 1; i_desc >= 0; --i_desc) { + const tir::ForNode* desc_loop = desc_loops[i_desc]; + const auto* int_desc_extent = desc_loop->extent.as(); + + for (int i_block = next_block_ind; i_block >= 0; --i_block) { + const tir::ForNode* block_loop = block_loops[i_block]; + const auto* int_block_extent = block_loop->extent.as(); + + if (int_block_extent->value % int_desc_extent->value != 0) continue; + if (iter_types_block[i_block] != iter_types_desc[i_desc]) continue; + + const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop]; + ret->loop_map.Set(block_loop_sref, GetRef(desc_loop)); + next_block_ind = i_block - 1; + break; + } + } + + for (int i = 0, n = desc_loops.size(); i < n; ++i) { + ret->desc_loop_indexer.Set(GetRef(desc_loops[i]), Integer(i)); + } + return TensorizeInfo(ret); +} + } // namespace tir } // namespace tvm From e0c3337c8a7050d63c956e31e473b283aac48898 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 18 Apr 2022 17:17:06 +0900 Subject: [PATCH 02/19] expose PreOrderVisit to python --- python/tvm/tir/stmt_functor.py | 12 ++++++++++++ src/tir/ir/stmt_functor.cc | 4 ++++ 2 files changed, 16 insertions(+) diff --git a/python/tvm/tir/stmt_functor.py b/python/tvm/tir/stmt_functor.py index 56dc1c20c2b3..7c66c5ad85ce 100644 --- a/python/tvm/tir/stmt_functor.py +++ b/python/tvm/tir/stmt_functor.py @@ -58,6 +58,18 @@ def post_order_visit(stmt, fvisit): return _ffi_api.PostOrderVisit(stmt, fvisit) # type: ignore +def pre_order_visit(stmt, fvisit): + """Recursive pre-order visit on stmt AST, applying fvisit on each node. + If fvisit returns false, it won't visit the children of the node. + + Parameters + ---------- + fvisit: function + The visitor function. + """ + return _ffi_api.PreOrderVisit(stmt, fvisit) # type: ignore + + def substitute(node, vmap): """Substitute the var specified by vmap. 
diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index c4d7ad0f6c67..06933c2c0dcb 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -792,6 +792,10 @@ TVM_REGISTER_GLOBAL("tir.PostOrderVisit").set_body_typed([](ObjectRef node, Pack tir::PostOrderVisit(node, [f](const ObjectRef& n) { f(n); }); }); +TVM_REGISTER_GLOBAL("tir.PreOrderVisit").set_body_typed([](ObjectRef node, PackedFunc f) { + tir::PreOrderVisit(node, [f](const ObjectRef& n) { return f(n); }); +}); + TVM_REGISTER_GLOBAL("tir.Substitute") .set_body_typed([](ObjectRef node, Map vmap) -> ObjectRef { if (node->IsInstance()) { From 51df94dc2b2e710095bb419a51c4409856621775 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 18 Apr 2022 17:17:33 +0900 Subject: [PATCH 03/19] add test case --- python/tvm/tir/schedule/analysis.py | 5 ++ src/tir/schedule/analysis/analysis.cc | 5 ++ .../unittest/test_tir_schedule_analysis.py | 79 ++++++++++++++++--- 3 files changed, 76 insertions(+), 13 deletions(-) diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py index f2fb7c4f3d1d..35b1633a20a6 100644 --- a/python/tvm/tir/schedule/analysis.py +++ b/python/tvm/tir/schedule/analysis.py @@ -56,3 +56,8 @@ def suggest_index_map( loops, predicate, ) + + +def get_tensorize_loop_mapping(state, block_sref, desc_func): + """TODO""" + return _ffi_api.GetTensorizeLoopMapping(state, block_sref, desc_func) # type: ignore diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 8245e7b11cb1..96598cbc1943 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -2130,5 +2130,10 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, return TensorizeInfo(ret); } +TVM_REGISTER_GLOBAL("tir.schedule.GetTensorizeLoopMapping") + .set_body_typed([](Schedule sch, BlockRV block, PrimFunc desc_func) { + return GetTensorizeLoopMapping(sch->state(), sch->GetSRef(block), desc_func); + }); + } // namespace tir } // namespace tvm diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py index 760b412ac804..aed14468060b 100644 --- a/tests/python/unittest/test_tir_schedule_analysis.py +++ b/tests/python/unittest/test_tir_schedule_analysis.py @@ -17,18 +17,15 @@ # pylint: disable=missing-docstring from typing import List -from tvm.tir import ( - Evaluate, - For, - ForKind, - IndexMap, - Var, - decl_buffer, - floordiv, - floormod, -) +import tvm +from tvm.tir.tensor_intrin.x86 import dot_product_16x4_u8i8i32_desc + + +from tvm.tir import Evaluate, For, ForKind, IndexMap, Var, decl_buffer, floordiv, floormod, Schedule from tvm.tir.analysis import expr_deep_equal -from tvm.tir.schedule.analysis import suggest_index_map +from tvm.tir.schedule.analysis import suggest_index_map, get_tensorize_loop_mapping +from tvm.script import tir as T +from tvm.tir.stmt_functor import pre_order_visit def _make_vars(*args: str) -> List[Var]: @@ -102,6 +99,62 @@ def test_suggest_index_map_bijective(): _assert_equal_index_map(index_map, expected_index_map) +@tvm.script.ir_module +class DenseVNNIModule: + @T.prim_func + def main( + placeholder: T.Buffer[(1024, 1024), "uint8"], + placeholder_1: T.Buffer[(64, 256, 16, 4), "int8"], + compute: T.Buffer[(1024, 1024), "int32"], + ) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + with T.block("root"): + T.reads() + T.writes() + for i0, i1, i2 in T.grid(1024, 1024, 1024): + 
with T.block("compute"): + i, j, k = T.axis.remap("SSR", [i0, i1, i2]) + T.reads(placeholder[i, k], placeholder_1[j // 16, k // 4, j % 16, k % 4]) + T.writes(compute[i, j]) + with T.init(): + compute[i, j] = 0 + compute[i, j] = compute[i, j] + T.cast(placeholder[i, k], "int32") * T.cast( + placeholder_1[j // 16, k // 4, j % 16, k % 4], "int32" + ) + + +def collect_loops(prim_func): + loops = [] + + def callback(node): + if isinstance(node, tvm.tir.For): + loops.append(node) + return True + + pre_order_visit(prim_func.body, callback) + + return loops + + +def test_get_tensorize_loop_mapping(): + s = Schedule(DenseVNNIModule) + block = s.get_block("compute") + + info = get_tensorize_loop_mapping(s, block, dot_product_16x4_u8i8i32_desc) + + desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items()) + + desc_loops = collect_loops(dot_product_16x4_u8i8i32_desc) + _, loop_j, loop_k = s.get_loops(block) + + assert desc_loops[0] in desc_loop_to_sref and desc_loops[1] in desc_loop_to_sref + assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(loop_j) + assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(loop_k) + + if __name__ == "__main__": - test_suggest_index_map_simple() - test_suggest_index_map_bijective() + # test_suggest_index_map_simple() + # test_suggest_index_map_bijective() + test_get_tensorize_loop_mapping() From 84801b6482e9fd00969285a8f0c6c413d12cda3b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 18 Apr 2022 17:34:51 +0900 Subject: [PATCH 04/19] add conv2d nchwc test --- .../unittest/test_tir_schedule_analysis.py | 63 +++++++++++++++++-- 1 file changed, 59 insertions(+), 4 deletions(-) diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py index aed14468060b..2b2f40b21997 100644 --- a/tests/python/unittest/test_tir_schedule_analysis.py +++ b/tests/python/unittest/test_tir_schedule_analysis.py @@ -107,9 +107,7 @@ def main( placeholder_1: T.Buffer[(64, 256, 16, 4), "int8"], compute: T.Buffer[(1024, 1024), "int32"], ) -> None: - # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - # body with T.block("root"): T.reads() T.writes() @@ -125,6 +123,46 @@ def main( ) +@tvm.script.ir_module +class Conv2dNCHWcVNNIModule: + @T.prim_func + def main( + placeholder: T.Buffer[(1, 4, 56, 56, 16), "uint8"], + placeholder_1: T.Buffer[(16, 4, 1, 1, 4, 16, 4), "int8"], + conv2d_NCHWc_int8: T.Buffer[(1, 16, 56, 56, 16), "int32"], + ) -> None: + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4): + with T.block("conv2d_NCHWc_int8"): + ( + n, + oc_chunk, + oh, + ow, + oc_block, + kh, + kw, + ic_outer, + ic_f_inner, + ic_s_inner, + ) = T.axis.remap("SSSSSRRRRR", [i0, i1, i2, i3, i4, i5, i6, i7, i8, i9]) + T.reads( + placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], + placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], + ) + T.writes(conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block]) + with T.init(): + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = 0 + conv2d_NCHWc_int8[n, oc_chunk, oh, ow, oc_block] = conv2d_NCHWc_int8[ + n, oc_chunk, oh, ow, oc_block + ] + T.cast( + placeholder[n, ic_outer, oh + kh, ow + kw, ic_f_inner * 4 + ic_s_inner], "int32" + ) * T.cast( + placeholder_1[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner], + "int32", + ) + + def collect_loops(prim_func): loops = [] @@ -138,7 +176,7 @@ def callback(node): return 
loops -def test_get_tensorize_loop_mapping(): +def test_get_tensorize_loop_mapping_dense_vnni(): s = Schedule(DenseVNNIModule) block = s.get_block("compute") @@ -154,7 +192,24 @@ def test_get_tensorize_loop_mapping(): assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(loop_k) +def test_get_tensorize_loop_mapping_conv2d_nchwc_vnni(): + s = Schedule(Conv2dNCHWcVNNIModule) + block = s.get_block("conv2d_NCHWc_int8") + + info = get_tensorize_loop_mapping(s, block, dot_product_16x4_u8i8i32_desc) + + desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items()) + + desc_loops = collect_loops(dot_product_16x4_u8i8i32_desc) + _, _, _, _, i4, _, _, _, _, i9 = s.get_loops(block) + + assert desc_loops[0] in desc_loop_to_sref and desc_loops[1] in desc_loop_to_sref + assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(i4) + assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(i9) + + if __name__ == "__main__": # test_suggest_index_map_simple() # test_suggest_index_map_bijective() - test_get_tensorize_loop_mapping() + # test_get_tensorize_loop_mapping_dense_vnni() + test_get_tensorize_loop_mapping_conv2d_nchwc_vnni() From fcd7917955750c689a14901dc57dca8596fbd253 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 18 Apr 2022 17:58:38 +0900 Subject: [PATCH 05/19] add mma test --- .../unittest/test_tir_schedule_analysis.py | 46 ++++++++++++++++++- 1 file changed, 45 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py index 2b2f40b21997..d36476d78e85 100644 --- a/tests/python/unittest/test_tir_schedule_analysis.py +++ b/tests/python/unittest/test_tir_schedule_analysis.py @@ -26,6 +26,8 @@ from tvm.tir.schedule.analysis import suggest_index_map, get_tensorize_loop_mapping from tvm.script import tir as T from tvm.tir.stmt_functor import pre_order_visit +from tvm.meta_schedule.testing import te_workload +from tvm.te import create_prim_func def _make_vars(*args: str) -> List[Var]: @@ -208,8 +210,50 @@ def test_get_tensorize_loop_mapping_conv2d_nchwc_vnni(): assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(i9) +def test_get_tensorize_loop_mapping_matmul_mma(): + @T.prim_func + def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) + B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) + C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) + + with T.block("root"): + T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16]) + T.writes(C[0:16, 0:16]) + for i, j, k in T.grid(16, 16, 16): + with T.block("update"): + vii, vjj, vkk = T.axis.remap("SSR", [i, j, k]) + C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk] + + matmul = create_prim_func( + te_workload.matmul_relu( + n=512, + m=512, + k=512, + ) + ) + + s = Schedule(matmul) + block = s.get_block("C") + + info = get_tensorize_loop_mapping(s, block, mma_desc) + + desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items()) + + desc_loops = collect_loops(mma_desc) + i0, i1, i2 = s.get_loops(block) + + for i in range(3): + assert desc_loops[i] in desc_loop_to_sref + + assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(i0) + assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(i1) + assert s.get(desc_loop_to_sref[desc_loops[2]]) == s.get(i2) + + if __name__ == "__main__": # test_suggest_index_map_simple() # test_suggest_index_map_bijective() # test_get_tensorize_loop_mapping_dense_vnni() - test_get_tensorize_loop_mapping_conv2d_nchwc_vnni() + # 
test_get_tensorize_loop_mapping_conv2d_nchwc_vnni() + test_get_tensorize_loop_mapping_matmul_mma() From 65682c29972b1b79c26046b6585aa70c510bf1db Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 18 Apr 2022 18:36:54 +0900 Subject: [PATCH 06/19] add arm nhwc conv2d test --- .../unittest/test_tir_schedule_analysis.py | 88 ++++++++++++++++++- 1 file changed, 87 insertions(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py index d36476d78e85..9c1500065af0 100644 --- a/tests/python/unittest/test_tir_schedule_analysis.py +++ b/tests/python/unittest/test_tir_schedule_analysis.py @@ -27,7 +27,11 @@ from tvm.script import tir as T from tvm.tir.stmt_functor import pre_order_visit from tvm.meta_schedule.testing import te_workload +from tvm.meta_schedule.relay_integration import extract_task_from_relay from tvm.te import create_prim_func +from tvm import relay +import numpy as np +from tvm.meta_schedule.tune import Parse def _make_vars(*args: str) -> List[Var]: @@ -251,9 +255,91 @@ def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None: assert s.get(desc_loop_to_sref[desc_loops[2]]) == s.get(i2) +def test_get_tensorize_loop_mapping_conv2d_nhwc_arm(): + @T.prim_func + def gemm_4x4x4_i8i8i32(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, (4, 4), offset_factor=1, dtype="int8") + B = T.match_buffer(b, (4, 4), offset_factor=1, dtype="int8") + C = T.match_buffer(c, (4, 4), offset_factor=1, dtype="int8") + + with T.block("root"): + T.reads(C[0:4, 0:4], A[0:4, 0:4], B[0:4, 0:4]) + T.writes(C[0:4, 0:4]) + for i, j, k in T.grid(4, 4, 4): + with T.block("update"): + vii, vjj, vkk = T.axis.remap("SSR", [i, j, k]) + C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk] + + data_shape = (8, 64, 56, 56) + weight_shape = (64, 64, 3, 3) + + data_dtype = "int8" + weight_dtype = "int8" + out_dtype = "int32" + + data = relay.var("data", shape=data_shape, dtype=data_dtype) + weight = relay.var("weight", shape=weight_shape, dtype=weight_dtype) + out_channel = weight_shape[0] + conv2d = relay.nn.conv2d( + data=data, + weight=weight, + kernel_size=weight_shape[2:], + channels=out_channel, + padding=(1, 1), + strides=(1, 1), + out_dtype=out_dtype, + ) + + relay_mod = tvm.IRModule.from_expr(conv2d) + + data = np.random.randint(low=-127, high=128, size=data_shape).astype("int8") + weight_np = np.random.randint(low=-127, high=128, size=weight_shape).astype("int8") + + def convert_conv2d_layout(mod, desired_layouts): + with tvm.transform.PassContext(opt_level=3): + seq = tvm.transform.Sequential([relay.transform.ConvertLayout(desired_layouts)]) + return seq(mod) + + relay_mod = convert_conv2d_layout(relay_mod, {"nn.conv2d": ["NHWC", "HWIO"]}) + + params = {"weight": weight_np} + + target = "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+dotprod" + extracted_tasks = extract_task_from_relay(relay_mod, target, params) + + conv2d_tasks = list( + filter( + lambda task: "conv2d" in task.task_name, + extracted_tasks, + ) + ) + + mod = Parse._mod(conv2d_tasks[0].dispatched[0]) + + s = tvm.tir.Schedule(mod) + + block = s.get_block("C") + + info = get_tensorize_loop_mapping(s, block, gemm_4x4x4_i8i8i32) + + desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items()) + + desc_loops = collect_loops(gemm_4x4x4_i8i8i32) + + for i in range(3): + assert desc_loops[i] in desc_loop_to_sref + + _, i1_5, i2_4, i3_3 = s.get_loops(block) + + assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(i1_5) + 
assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(i2_4) + assert s.get(desc_loop_to_sref[desc_loops[2]]) == s.get(i3_3) + + if __name__ == "__main__": # test_suggest_index_map_simple() # test_suggest_index_map_bijective() # test_get_tensorize_loop_mapping_dense_vnni() # test_get_tensorize_loop_mapping_conv2d_nchwc_vnni() - test_get_tensorize_loop_mapping_matmul_mma() + # test_get_tensorize_loop_mapping_matmul_mma() + test_get_tensorize_loop_mapping_conv2d_nhwc_arm() From fcca9fba4308b42675d62451cd4d83ba5b8df70f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 18 Apr 2022 19:11:58 +0900 Subject: [PATCH 07/19] Revert "add arm nhwc conv2d test" This reverts commit eb147f33bb02d62a0eacc9cdfe777ac047ee1bc9. --- .../unittest/test_tir_schedule_analysis.py | 88 +------------------ 1 file changed, 1 insertion(+), 87 deletions(-) diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py index 9c1500065af0..d36476d78e85 100644 --- a/tests/python/unittest/test_tir_schedule_analysis.py +++ b/tests/python/unittest/test_tir_schedule_analysis.py @@ -27,11 +27,7 @@ from tvm.script import tir as T from tvm.tir.stmt_functor import pre_order_visit from tvm.meta_schedule.testing import te_workload -from tvm.meta_schedule.relay_integration import extract_task_from_relay from tvm.te import create_prim_func -from tvm import relay -import numpy as np -from tvm.meta_schedule.tune import Parse def _make_vars(*args: str) -> List[Var]: @@ -255,91 +251,9 @@ def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None: assert s.get(desc_loop_to_sref[desc_loops[2]]) == s.get(i2) -def test_get_tensorize_loop_mapping_conv2d_nhwc_arm(): - @T.prim_func - def gemm_4x4x4_i8i8i32(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (4, 4), offset_factor=1, dtype="int8") - B = T.match_buffer(b, (4, 4), offset_factor=1, dtype="int8") - C = T.match_buffer(c, (4, 4), offset_factor=1, dtype="int8") - - with T.block("root"): - T.reads(C[0:4, 0:4], A[0:4, 0:4], B[0:4, 0:4]) - T.writes(C[0:4, 0:4]) - for i, j, k in T.grid(4, 4, 4): - with T.block("update"): - vii, vjj, vkk = T.axis.remap("SSR", [i, j, k]) - C[vii, vjj] = C[vii, vjj] + A[vii, vkk] * B[vjj, vkk] - - data_shape = (8, 64, 56, 56) - weight_shape = (64, 64, 3, 3) - - data_dtype = "int8" - weight_dtype = "int8" - out_dtype = "int32" - - data = relay.var("data", shape=data_shape, dtype=data_dtype) - weight = relay.var("weight", shape=weight_shape, dtype=weight_dtype) - out_channel = weight_shape[0] - conv2d = relay.nn.conv2d( - data=data, - weight=weight, - kernel_size=weight_shape[2:], - channels=out_channel, - padding=(1, 1), - strides=(1, 1), - out_dtype=out_dtype, - ) - - relay_mod = tvm.IRModule.from_expr(conv2d) - - data = np.random.randint(low=-127, high=128, size=data_shape).astype("int8") - weight_np = np.random.randint(low=-127, high=128, size=weight_shape).astype("int8") - - def convert_conv2d_layout(mod, desired_layouts): - with tvm.transform.PassContext(opt_level=3): - seq = tvm.transform.Sequential([relay.transform.ConvertLayout(desired_layouts)]) - return seq(mod) - - relay_mod = convert_conv2d_layout(relay_mod, {"nn.conv2d": ["NHWC", "HWIO"]}) - - params = {"weight": weight_np} - - target = "llvm --device arm_cpu --mtriple aarch64-linux-gnu -mattr=+v8.2a,+dotprod" - extracted_tasks = extract_task_from_relay(relay_mod, target, params) - - conv2d_tasks = list( - filter( - lambda task: "conv2d" in task.task_name, - extracted_tasks, - ) - ) - - mod = 
Parse._mod(conv2d_tasks[0].dispatched[0]) - - s = tvm.tir.Schedule(mod) - - block = s.get_block("C") - - info = get_tensorize_loop_mapping(s, block, gemm_4x4x4_i8i8i32) - - desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items()) - - desc_loops = collect_loops(gemm_4x4x4_i8i8i32) - - for i in range(3): - assert desc_loops[i] in desc_loop_to_sref - - _, i1_5, i2_4, i3_3 = s.get_loops(block) - - assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(i1_5) - assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(i2_4) - assert s.get(desc_loop_to_sref[desc_loops[2]]) == s.get(i3_3) - - if __name__ == "__main__": # test_suggest_index_map_simple() # test_suggest_index_map_bijective() # test_get_tensorize_loop_mapping_dense_vnni() # test_get_tensorize_loop_mapping_conv2d_nchwc_vnni() - # test_get_tensorize_loop_mapping_matmul_mma() - test_get_tensorize_loop_mapping_conv2d_nhwc_arm() + test_get_tensorize_loop_mapping_matmul_mma() From 4eb584532d73ae207eba6a6dfabe24410a0319bd Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 18 Apr 2022 19:18:43 +0900 Subject: [PATCH 08/19] refine --- .../unittest/test_tir_schedule_analysis.py | 25 +++++++++++-------- 1 file changed, 14 insertions(+), 11 deletions(-) diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py index d36476d78e85..739974252ebc 100644 --- a/tests/python/unittest/test_tir_schedule_analysis.py +++ b/tests/python/unittest/test_tir_schedule_analysis.py @@ -203,6 +203,9 @@ def test_get_tensorize_loop_mapping_conv2d_nchwc_vnni(): desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items()) desc_loops = collect_loops(dot_product_16x4_u8i8i32_desc) + + # i4 corresonds to the inner output channel axis of the NCHWc output tensor + # for i0, i1, i2, i3, i4, i5, i6, i7, i8, i9 in T.grid(1, 16, 56, 56, 16, 1, 1, 4, 4, 4): _, _, _, _, i4, _, _, _, _, i9 = s.get_loops(block) assert desc_loops[0] in desc_loop_to_sref and desc_loops[1] in desc_loop_to_sref @@ -212,11 +215,11 @@ def test_get_tensorize_loop_mapping_conv2d_nchwc_vnni(): def test_get_tensorize_loop_mapping_matmul_mma(): @T.prim_func - def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), align=128, offset_factor=1) - B = T.match_buffer(b, (16, 16), align=128, offset_factor=1) - C = T.match_buffer(c, (16, 16), align=128, offset_factor=1) - + def matmul_16x16x16xf16f16f16_desc( + A: T.Buffer((16, 16), "float16", align=128, offset_factor=1), + B: T.Buffer((16, 16), "float16", align=128, offset_factor=1), + C: T.Buffer((16, 16), "float16", align=128, offset_factor=1), + ) -> None: with T.block("root"): T.reads(C[0:16, 0:16], A[0:16, 0:16], B[0:16, 0:16]) T.writes(C[0:16, 0:16]) @@ -236,11 +239,11 @@ def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None: s = Schedule(matmul) block = s.get_block("C") - info = get_tensorize_loop_mapping(s, block, mma_desc) + info = get_tensorize_loop_mapping(s, block, matmul_16x16x16xf16f16f16_desc) desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items()) - desc_loops = collect_loops(mma_desc) + desc_loops = collect_loops(matmul_16x16x16xf16f16f16_desc) i0, i1, i2 = s.get_loops(block) for i in range(3): @@ -252,8 +255,8 @@ def mma_desc(a: T.handle, b: T.handle, c: T.handle) -> None: if __name__ == "__main__": - # test_suggest_index_map_simple() - # test_suggest_index_map_bijective() - # test_get_tensorize_loop_mapping_dense_vnni() - # test_get_tensorize_loop_mapping_conv2d_nchwc_vnni() + test_suggest_index_map_simple() + 
test_suggest_index_map_bijective() + test_get_tensorize_loop_mapping_dense_vnni() + test_get_tensorize_loop_mapping_conv2d_nchwc_vnni() test_get_tensorize_loop_mapping_matmul_mma() From f759f43b010bcc37867066c6a7ced3999a920b33 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 18 Apr 2022 20:01:10 +0900 Subject: [PATCH 09/19] add doc --- python/tvm/tir/schedule/analysis.py | 33 ++++++++++++++++--- python/tvm/tir/stmt_functor.py | 4 +-- src/tir/schedule/analysis.h | 13 ++++++-- .../unittest/test_tir_schedule_analysis.py | 4 ++- 4 files changed, 44 insertions(+), 10 deletions(-) diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py index 35b1633a20a6..817792ccd179 100644 --- a/python/tvm/tir/schedule/analysis.py +++ b/python/tvm/tir/schedule/analysis.py @@ -20,9 +20,13 @@ from ..buffer import Buffer from ..stmt import For from ..expr import PrimExpr -from ..function import IndexMap +from ..function import IndexMap, PrimFunc from . import _ffi_api +from .schedule import Schedule, BlockRV + +import tvm._ffi +from tvm.runtime import Object def suggest_index_map( @@ -58,6 +62,27 @@ def suggest_index_map( ) -def get_tensorize_loop_mapping(state, block_sref, desc_func): - """TODO""" - return _ffi_api.GetTensorizeLoopMapping(state, block_sref, desc_func) # type: ignore +@tvm._ffi.register_object("tir.schedule.TensorizeInfo") +class TensorizeInfo(Object): + """Necessary information used for tensorization.""" + pass + + +def get_tensorize_loop_mapping(sch: Schedule, block: BlockRV, desc_func: PrimFunc) -> TensorizeInfo: + """Establish a mapping between loops in a target block and an intrinsic description + + Parameters + ---------- + sch : Schedule + The schedule to be tensorized + block : BlockRV + The target block to match against + desc_func : PrimFunc + The prim func describing the computation to be tensorized + + Returns + ------- + tensorize_info : Optional[TensorizeInfo] + TensorizeInfo structure if a valid mapping is found, None otherwise + """ + return _ffi_api.GetTensorizeLoopMapping(sch, block, desc_func) # type: ignore diff --git a/python/tvm/tir/stmt_functor.py b/python/tvm/tir/stmt_functor.py index 7c66c5ad85ce..5bcf4ae802c7 100644 --- a/python/tvm/tir/stmt_functor.py +++ b/python/tvm/tir/stmt_functor.py @@ -60,11 +60,11 @@ def post_order_visit(stmt, fvisit): def pre_order_visit(stmt, fvisit): """Recursive pre-order visit on stmt AST, applying fvisit on each node. - If fvisit returns false, it won't visit the children of the node. + If fvisit returns False, it won't visit the children of the node. Parameters ---------- - fvisit: function + fvisit: function of the signature Object -> bool The visitor function. """ return _ffi_api.PreOrderVisit(stmt, fvisit) # type: ignore diff --git a/src/tir/schedule/analysis.h b/src/tir/schedule/analysis.h index cee45e1a4398..c9c3d72ae0b5 100644 --- a/src/tir/schedule/analysis.h +++ b/src/tir/schedule/analysis.h @@ -659,9 +659,9 @@ Array AnalyzeRegionLowerBound(const BufferRegion& region, const P /*! \brief Necessary information used for tensorization */ class TensorizeInfoNode : public Object { public: - /*! \brief Maps block loops to desc loops */ + /*! \brief Maps loops in a target block to the ones in an intrinsic description */ Map loop_map; - /*! \brief Maps loops in desc to its index, outer to inner */ + /*! 
\brief Maps loops in an intrinsic description to its index, outer to inner */ Map desc_loop_indexer; void VisitAttrs(AttrVisitor* v) { @@ -669,7 +669,7 @@ class TensorizeInfoNode : public Object { v->Visit("desc_loop_indexer", &desc_loop_indexer); } - static constexpr const char* _type_key = "tir.analysis.TensorizeInfo"; + static constexpr const char* _type_key = "tir.schedule.TensorizeInfo"; TVM_DECLARE_FINAL_OBJECT_INFO(TensorizeInfoNode, Object); }; @@ -678,6 +678,13 @@ class TensorizeInfo : public ObjectRef { TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(TensorizeInfo, ObjectRef, TensorizeInfoNode); }; +/*! + * \brief Establish a mapping between loops in a target block and an intrinsic description + * \param self The schedule state to be tensorized + * \param block_sref The target block to match against + * \param desc_func The prim func describing the computation to be tensorized + * \return TensorizeInfo structure if a valid mapping is found, NullOpt otherwise + */ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, const tir::StmtSRef& block_sref, const tir::PrimFunc& desc_func); diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py index 739974252ebc..c9bee54073ca 100644 --- a/tests/python/unittest/test_tir_schedule_analysis.py +++ b/tests/python/unittest/test_tir_schedule_analysis.py @@ -23,7 +23,7 @@ from tvm.tir import Evaluate, For, ForKind, IndexMap, Var, decl_buffer, floordiv, floormod, Schedule from tvm.tir.analysis import expr_deep_equal -from tvm.tir.schedule.analysis import suggest_index_map, get_tensorize_loop_mapping +from tvm.tir.schedule.analysis import suggest_index_map, get_tensorize_loop_mapping, TensorizeInfo from tvm.script import tir as T from tvm.tir.stmt_functor import pre_order_visit from tvm.meta_schedule.testing import te_workload @@ -184,6 +184,8 @@ def test_get_tensorize_loop_mapping_dense_vnni(): info = get_tensorize_loop_mapping(s, block, dot_product_16x4_u8i8i32_desc) + assert isinstance(info, TensorizeInfo) + desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items()) desc_loops = collect_loops(dot_product_16x4_u8i8i32_desc) From f0caa77b37acfddf016e0570aa08cfec5901d2f7 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Mon, 18 Apr 2022 20:18:23 +0900 Subject: [PATCH 10/19] update --- src/tir/schedule/analysis/analysis.cc | 43 ++++++++++++++------------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 96598cbc1943..4c2d9ab8037b 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -16,6 +16,8 @@ * specific language governing permissions and limitations * under the License. 
*/ +#include + #include "../utils.h" namespace tvm { @@ -492,8 +494,7 @@ void CheckNotOutputBlock(const ScheduleState& self, const StmtSRef& block_sref, } } -std::vector GetBlockVarTypes(const StmtSRef& block_sref) { - const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); +std::vector GetBlockVarTypes(const BlockNode* block) { std::vector results; results.reserve(block->iter_vars.size()); for (const IterVar& iter_var : block->iter_vars) { @@ -502,6 +503,11 @@ std::vector GetBlockVarTypes(const StmtSRef& block_sref) { return results; } +std::vector GetBlockVarTypes(const StmtSRef& block_sref) { + const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); + return GetBlockVarTypes(block); +} + bool IsWriteCache(const StmtSRef& block_sref) { const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref); if (block->writes.size() != 1) { @@ -2033,11 +2039,6 @@ TVM_REGISTER_NODE_TYPE(TensorizeInfoNode); Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, const tir::StmtSRef& block_sref, const tir::PrimFunc& desc_func) { - // Try to do tiling automatically if possible - // Now the heuristic is that if block's block var binding is constant + loop var, - // in other words, with tir.block(..., vi=Ci+i, vj=Cj+j, vk=Ck+k), then we split and reorder - // i, j, k according to the loops outside desc_block - // Collect the loops outside block arith::Analyzer analyzer; const tir::BlockRealize& block = tir::GetBlockRealize(self, block_sref); // Step 1. Analyze desc_func, extract its block, loops and loop vars @@ -2054,7 +2055,7 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, desc_block = block; return false; } - // Extract the loops + // Extract loops if (const auto* loop = obj.as()) { desc_loops.push_back(loop); desc_loop_vars.insert(loop->loop_var.get()); @@ -2068,8 +2069,7 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, std::reverse(desc_loops.begin(), desc_loops.end()); ICHECK(desc_block); } - // Step 2. Check if `desc_block` matches `block` - // Ignore the scope of buffers when comparing, since we can do cache_read/write + // Step 2. Collect loops from block_sref const tir::StmtSRef& scope_sref = GetScopeRoot(self, block_sref, false); const tir::BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_block, scope_sref); std::vector block_loops; @@ -2088,32 +2088,33 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, } std::reverse(block_loops.begin(), block_loops.end()); } - // Step 4. Map from block loops to desc block loops + // Step 3. 
Map from block loops to desc block loops ObjectPtr ret = make_object(); - int n_block_vars = block->iter_values.size(); - int n_desc_vars = desc_block->iter_values.size(); - int offset = n_block_vars - n_desc_vars; + const int n_block_vars = block->iter_values.size(); + const int n_desc_vars = desc_block->iter_values.size(); + const int offset = n_block_vars - n_desc_vars; + if (offset < 0) { return NullOpt; } - std::vector iter_types_desc; - for (const IterVar& iter_var : desc_block->block->iter_vars) { - iter_types_desc.push_back(iter_var->iter_type); - } + const std::vector iter_types_block = GetBlockVarTypes(block_sref); + const std::vector iter_types_desc = GetBlockVarTypes(desc_block->block.get()); - std::vector iter_types_block = GetBlockVarTypes(block_sref); + ICHECK(iter_types_block.size() == iter_types_desc.size()); ICHECK(block_loops.size() == iter_types_block.size()); int next_block_ind = block_loops.size() - 1; for (int i_desc = n_desc_vars - 1; i_desc >= 0; --i_desc) { const tir::ForNode* desc_loop = desc_loops[i_desc]; - const auto* int_desc_extent = desc_loop->extent.as(); + const IntImmNode* int_desc_extent = desc_loop->extent.as(); + if (!int_desc_extent) continue; for (int i_block = next_block_ind; i_block >= 0; --i_block) { const tir::ForNode* block_loop = block_loops[i_block]; - const auto* int_block_extent = block_loop->extent.as(); + const IntImmNode* int_block_extent = block_loop->extent.as(); + if (!int_block_extent) continue; if (int_block_extent->value % int_desc_extent->value != 0) continue; if (iter_types_block[i_block] != iter_types_desc[i_desc]) continue; From 0df73cb061e0cc23e4376111e51f0424753b73cd Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 19 Apr 2022 06:59:25 +0900 Subject: [PATCH 11/19] fixd condition --- src/tir/schedule/analysis/analysis.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 4c2d9ab8037b..efd3db46ca09 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -2101,7 +2101,7 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, const std::vector iter_types_block = GetBlockVarTypes(block_sref); const std::vector iter_types_desc = GetBlockVarTypes(desc_block->block.get()); - ICHECK(iter_types_block.size() == iter_types_desc.size()); + ICHECK(desc_loops.size() == static_cast(n_desc_vars)); ICHECK(block_loops.size() == iter_types_block.size()); int next_block_ind = block_loops.size() - 1; From 46eed2a7e9aea415d41abeeae9229d5048f50a47 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 19 Apr 2022 07:00:09 +0900 Subject: [PATCH 12/19] black --- python/tvm/tir/schedule/analysis.py | 1 + 1 file changed, 1 insertion(+) diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py index 817792ccd179..7122d8c68813 100644 --- a/python/tvm/tir/schedule/analysis.py +++ b/python/tvm/tir/schedule/analysis.py @@ -65,6 +65,7 @@ def suggest_index_map( @tvm._ffi.register_object("tir.schedule.TensorizeInfo") class TensorizeInfo(Object): """Necessary information used for tensorization.""" + pass From ecb3ebc8a7e30c7ee3c60ff502a30812e2924222 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 19 Apr 2022 07:37:28 +0900 Subject: [PATCH 13/19] pylint --- python/tvm/tir/schedule/analysis.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py index 
7122d8c68813..459931228a77 100644 --- a/python/tvm/tir/schedule/analysis.py +++ b/python/tvm/tir/schedule/analysis.py @@ -17,6 +17,9 @@ """Analysis used in TensorIR scheduling""" from typing import List, Optional +import tvm._ffi +from tvm.runtime import Object + from ..buffer import Buffer from ..stmt import For from ..expr import PrimExpr @@ -25,9 +28,6 @@ from . import _ffi_api from .schedule import Schedule, BlockRV -import tvm._ffi -from tvm.runtime import Object - def suggest_index_map( buffer: Buffer, @@ -66,8 +66,6 @@ def suggest_index_map( class TensorizeInfo(Object): """Necessary information used for tensorization.""" - pass - def get_tensorize_loop_mapping(sch: Schedule, block: BlockRV, desc_func: PrimFunc) -> TensorizeInfo: """Establish a mapping between loops in a target block and an intrinsic description From 0860abc7c3ddd5c8e7864b813841408ba8631000 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 19 Apr 2022 13:03:07 +0900 Subject: [PATCH 14/19] Update python/tvm/tir/schedule/analysis.py Co-authored-by: Ruihang Lai --- python/tvm/tir/schedule/analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py index 459931228a77..a5bbd58d014d 100644 --- a/python/tvm/tir/schedule/analysis.py +++ b/python/tvm/tir/schedule/analysis.py @@ -67,7 +67,7 @@ class TensorizeInfo(Object): """Necessary information used for tensorization.""" -def get_tensorize_loop_mapping(sch: Schedule, block: BlockRV, desc_func: PrimFunc) -> TensorizeInfo: +def get_tensorize_loop_mapping(sch: Schedule, block: BlockRV, desc_func: PrimFunc) -> Optional[TensorizeInfo]: """Establish a mapping between loops in a target block and an intrinsic description Parameters From ec39b62d59a6fb5301854ccabf2d3c522da44d76 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 19 Apr 2022 13:03:33 +0900 Subject: [PATCH 15/19] run black --- python/tvm/tir/schedule/analysis.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/python/tvm/tir/schedule/analysis.py b/python/tvm/tir/schedule/analysis.py index a5bbd58d014d..71ff024217c7 100644 --- a/python/tvm/tir/schedule/analysis.py +++ b/python/tvm/tir/schedule/analysis.py @@ -67,7 +67,9 @@ class TensorizeInfo(Object): """Necessary information used for tensorization.""" -def get_tensorize_loop_mapping(sch: Schedule, block: BlockRV, desc_func: PrimFunc) -> Optional[TensorizeInfo]: +def get_tensorize_loop_mapping( + sch: Schedule, block: BlockRV, desc_func: PrimFunc +) -> Optional[TensorizeInfo]: """Establish a mapping between loops in a target block and an intrinsic description Parameters From 9ec0974d24763ee7e43900a16ec6b44564f80704 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 20 Apr 2022 05:50:28 +0900 Subject: [PATCH 16/19] bring back logic in original code to support loop permutation --- src/tir/schedule/analysis/analysis.cc | 59 +++++++++++++++---- .../unittest/test_tir_schedule_analysis.py | 24 ++++---- 2 files changed, 62 insertions(+), 21 deletions(-) diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index efd3db46ca09..73cd283cb2de 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -16,6 +16,7 @@ * specific language governing permissions and limitations * under the License. 
*/ +#include #include #include "../utils.h" @@ -2106,22 +2107,58 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, int next_block_ind = block_loops.size() - 1; for (int i_desc = n_desc_vars - 1; i_desc >= 0; --i_desc) { - const tir::ForNode* desc_loop = desc_loops[i_desc]; + // Step 4.2. Find the corresponding loop of the i-th block var of desc + const PrimExpr& desc_bind = desc_block->iter_values[i_desc]; + const tir::ForNode* desc_loop = nullptr; + IterVarType iter_type_desc; + for (int i = 0, n = desc_loops.size(); i < n; ++i) { + // Check if desc_bind = loops[i]->loop_var + stuff-irrelevant-of-loop-vars + PrimExpr r = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var); + if (!UsesVar(r, + [&desc_loop_vars](const VarNode* var) { return desc_loop_vars.count(var); })) { + desc_loop = desc_loops[i]; + iter_type_desc = iter_types_desc[i]; + break; + } + } + if (desc_loop == nullptr || desc_loop->extent.as() == nullptr) { + return NullOpt; + } + const IntImmNode* int_desc_extent = desc_loop->extent.as(); - if (!int_desc_extent) continue; + const tir::ForNode* block_loop = nullptr; + + PrimExpr block_bind; for (int i_block = next_block_ind; i_block >= 0; --i_block) { - const tir::ForNode* block_loop = block_loops[i_block]; - const IntImmNode* int_block_extent = block_loop->extent.as(); + if (iter_types_block[i_block] == iter_type_desc) { + next_block_ind = i_block - 1; + block_bind = block->iter_values[i_block]; + break; + } + } - if (!int_block_extent) continue; - if (int_block_extent->value % int_desc_extent->value != 0) continue; - if (iter_types_block[i_block] != iter_types_desc[i_desc]) continue; + for (int i = 0, n = block_loops.size(); i < n; ++i) { + PrimExpr r = analyzer.Simplify(block_bind - block_loops[i]->loop_var); + if (!UsesVar(r, + [&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); })) { + block_loop = block_loops[i]; + const IntImmNode* int_block_extent = block_loop->extent.as(); - const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop]; - ret->loop_map.Set(block_loop_sref, GetRef(desc_loop)); - next_block_ind = i_block - 1; - break; + if (!int_block_extent || int_block_extent->value % int_desc_extent->value != 0) { + return NullOpt; + } + + const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop]; + auto it = ret->loop_map.find(block_loop_sref); + if (it == ret->loop_map.end()) { + ret->loop_map.Set(block_loop_sref, GetRef(desc_loop)); + } else if ((*it).second.get() != desc_loop) { + return NullOpt; + } + + break; + } } } diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py index c9bee54073ca..9cae3e9815b5 100644 --- a/tests/python/unittest/test_tir_schedule_analysis.py +++ b/tests/python/unittest/test_tir_schedule_analysis.py @@ -240,20 +240,24 @@ def matmul_16x16x16xf16f16f16_desc( s = Schedule(matmul) block = s.get_block("C") + i0, i1, i2 = s.get_loops(block) + desc_loops = collect_loops(matmul_16x16x16xf16f16f16_desc) - info = get_tensorize_loop_mapping(s, block, matmul_16x16x16xf16f16f16_desc) - - desc_loop_to_sref = dict((v, k) for k, v in info.loop_map.items()) + for do_reorder in [True, False]: + # Mapping should be invariant to the loop permutation + if do_reorder: + s.reorder(i2, i0, i1) - desc_loops = collect_loops(matmul_16x16x16xf16f16f16_desc) - i0, i1, i2 = s.get_loops(block) + info = get_tensorize_loop_mapping(s, block, matmul_16x16x16xf16f16f16_desc) + assert info is not None + desc_loop_to_sref = dict((v, k) for k, v in 
info.loop_map.items()) - for i in range(3): - assert desc_loops[i] in desc_loop_to_sref + for i in range(3): + assert desc_loops[i] in desc_loop_to_sref - assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(i0) - assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(i1) - assert s.get(desc_loop_to_sref[desc_loops[2]]) == s.get(i2) + assert s.get(desc_loop_to_sref[desc_loops[0]]) == s.get(i0) + assert s.get(desc_loop_to_sref[desc_loops[1]]) == s.get(i1) + assert s.get(desc_loop_to_sref[desc_loops[2]]) == s.get(i2) if __name__ == "__main__": From 2909a068dc6ac95cce0b81df58a28793a150d682 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 20 Apr 2022 06:24:44 +0900 Subject: [PATCH 17/19] add comment --- src/tir/schedule/analysis/analysis.cc | 41 ++++++++++++++++++++------- 1 file changed, 30 insertions(+), 11 deletions(-) diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 73cd283cb2de..17e442354542 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -2105,12 +2105,28 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, ICHECK(desc_loops.size() == static_cast(n_desc_vars)); ICHECK(block_loops.size() == iter_types_block.size()); + // We assume that the orders of iter_vars in the target and the desc block are consistent. + // Based on that assumption, the following logic supports arbitrary permutations of a loop order, + // such as + + // for k: + // for i: + // for j: + // C[i, j] += A[i, k] * B[k, j] + + // or + + // for i: + // for j: + // for k: + // C[i, j] += A[i, k] * B[k, j] + int next_block_ind = block_loops.size() - 1; for (int i_desc = n_desc_vars - 1; i_desc >= 0; --i_desc) { - // Step 4.2. Find the corresponding loop of the i-th block var of desc + // Step 3.1. Find the corresponding loop of the i_desc-th block var of desc const PrimExpr& desc_bind = desc_block->iter_values[i_desc]; const tir::ForNode* desc_loop = nullptr; - IterVarType iter_type_desc; + IterVarType iter_type_desc = iter_types_desc[i_desc]; for (int i = 0, n = desc_loops.size(); i < n; ++i) { // Check if desc_bind = loops[i]->loop_var + stuff-irrelevant-of-loop-vars PrimExpr r = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var); @@ -2127,29 +2143,32 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, const IntImmNode* int_desc_extent = desc_loop->extent.as(); - const tir::ForNode* block_loop = nullptr; - + // Step 3.2. Find the corresponding iter_value of the target block with a matching iterator type PrimExpr block_bind; - for (int i_block = next_block_ind; i_block >= 0; --i_block) { - if (iter_types_block[i_block] == iter_type_desc) { - next_block_ind = i_block - 1; - block_bind = block->iter_values[i_block]; + for (int i = next_block_ind; i >= 0; --i) { + if (iter_types_block[i] == iter_type_desc) { + next_block_ind = i - 1; + block_bind = block->iter_values[i]; break; } } + if (!block_bind.defined()) return NullOpt; + + // Step 3.3. 
Find the corresponding loop of the target block for (int i = 0, n = block_loops.size(); i < n; ++i) { + // Check if block_bind = block_loops[i]->loop_var + stuff-irrelevant-of-loop-vars PrimExpr r = analyzer.Simplify(block_bind - block_loops[i]->loop_var); if (!UsesVar(r, [&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); })) { - block_loop = block_loops[i]; - const IntImmNode* int_block_extent = block_loop->extent.as(); + const IntImmNode* int_block_extent = block_loops[i]->extent.as(); + // Check divisibility if (!int_block_extent || int_block_extent->value % int_desc_extent->value != 0) { return NullOpt; } - const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop]; + const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loops[i]]; auto it = ret->loop_map.find(block_loop_sref); if (it == ret->loop_map.end()) { ret->loop_map.Set(block_loop_sref, GetRef(desc_loop)); From f474003591405bd3d998bfdd440afc00b20d5bea Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 20 Apr 2022 07:16:04 +0900 Subject: [PATCH 18/19] simplify --- src/tir/schedule/analysis/analysis.cc | 37 +++++++++++++-------------- 1 file changed, 18 insertions(+), 19 deletions(-) diff --git a/src/tir/schedule/analysis/analysis.cc b/src/tir/schedule/analysis/analysis.cc index 17e442354542..4777ee2657b3 100644 --- a/src/tir/schedule/analysis/analysis.cc +++ b/src/tir/schedule/analysis/analysis.cc @@ -2129,8 +2129,8 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, IterVarType iter_type_desc = iter_types_desc[i_desc]; for (int i = 0, n = desc_loops.size(); i < n; ++i) { // Check if desc_bind = loops[i]->loop_var + stuff-irrelevant-of-loop-vars - PrimExpr r = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var); - if (!UsesVar(r, + PrimExpr residual = analyzer.Simplify(desc_bind - desc_loops[i]->loop_var); + if (!UsesVar(residual, [&desc_loop_vars](const VarNode* var) { return desc_loop_vars.count(var); })) { desc_loop = desc_loops[i]; iter_type_desc = iter_types_desc[i]; @@ -2158,26 +2158,25 @@ Optional GetTensorizeLoopMapping(const tir::ScheduleState& self, // Step 3.3. 
Find the corresponding loop of the target block for (int i = 0, n = block_loops.size(); i < n; ++i) { // Check if block_bind = block_loops[i]->loop_var + stuff-irrelevant-of-loop-vars - PrimExpr r = analyzer.Simplify(block_bind - block_loops[i]->loop_var); - if (!UsesVar(r, - [&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); })) { - const IntImmNode* int_block_extent = block_loops[i]->extent.as(); - - // Check divisibility - if (!int_block_extent || int_block_extent->value % int_desc_extent->value != 0) { - return NullOpt; - } + const tir::ForNode* block_loop = block_loops[i]; + const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loop]; + // Skip i-th loop if it has already been mapped + if (ret->loop_map.find(block_loop_sref) != ret->loop_map.end()) continue; - const tir::StmtSRef& block_loop_sref = self->stmt2ref[block_loops[i]]; - auto it = ret->loop_map.find(block_loop_sref); - if (it == ret->loop_map.end()) { - ret->loop_map.Set(block_loop_sref, GetRef(desc_loop)); - } else if ((*it).second.get() != desc_loop) { - return NullOpt; - } + PrimExpr residual = analyzer.Simplify(block_bind - block_loops[i]->loop_var); + if (UsesVar(residual, + [&block_loop_vars](const VarNode* var) { return block_loop_vars.count(var); })) + continue; - break; + const IntImmNode* int_block_extent = block_loops[i]->extent.as(); + + // Check divisibility + if (!int_block_extent || int_block_extent->value % int_desc_extent->value != 0) { + return NullOpt; } + + ret->loop_map.Set(block_loop_sref, GetRef(desc_loop)); + break; } } From 8750b4d171a7b13ca9ab2670efa4fc1039da6301 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 20 Apr 2022 07:36:03 +0900 Subject: [PATCH 19/19] minor fix to test --- tests/python/unittest/test_tir_schedule_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_schedule_analysis.py b/tests/python/unittest/test_tir_schedule_analysis.py index 9cae3e9815b5..10371d3ccaf1 100644 --- a/tests/python/unittest/test_tir_schedule_analysis.py +++ b/tests/python/unittest/test_tir_schedule_analysis.py @@ -243,7 +243,7 @@ def matmul_16x16x16xf16f16f16_desc( i0, i1, i2 = s.get_loops(block) desc_loops = collect_loops(matmul_16x16x16xf16f16f16_desc) - for do_reorder in [True, False]: + for do_reorder in [False, True]: # Mapping should be invariant to the loop permutation if do_reorder: s.reorder(i2, i0, i1)
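
For reference, a minimal usage sketch of the analysis entry point added in this series, written against the Python API from patch 09. It mirrors `test_get_tensorize_loop_mapping_dense_vnni` above: `DenseVNNIModule` and `dot_product_16x4_u8i8i32_desc` are the workload and intrinsic description defined in that test, and the final loop over `loop_map` is illustrative only, not part of the patches.

    from tvm.tir import Schedule
    from tvm.tir.schedule.analysis import get_tensorize_loop_mapping
    from tvm.tir.tensor_intrin.x86 import dot_product_16x4_u8i8i32_desc

    # Target workload: the dense VNNI module defined in the test above.
    s = Schedule(DenseVNNIModule)
    block = s.get_block("compute")

    # Returns None when no valid mapping exists, e.g. when iterator types or
    # extents of the block loops cannot be matched against the description.
    info = get_tensorize_loop_mapping(s, block, dot_product_16x4_u8i8i32_desc)

    if info is not None:
        # loop_map: loop sref in the target block -> loop in the description
        # desc_loop_indexer: description loop -> its index, outer to inner
        for block_loop_sref, desc_loop in info.loop_map.items():
            print(int(info.desc_loop_indexer[desc_loop]), s.get(block_loop_sref))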