From ec7796b9b8d087c81f7177e085aacd4c028962c3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Sun, 18 Sep 2022 07:49:14 +0900 Subject: [PATCH 1/4] [Metaschedule] Introduce MultiLevelTiling for wide vector architecture --- include/tvm/meta_schedule/schedule_rule.h | 15 +++ .../meta_schedule/schedule_rule/__init__.py | 1 + .../schedule_rule/multi_level_tiling.py | 37 ++++++ .../schedule_rule/multi_level_tiling.cc | 35 +++-- .../schedule_rule/multi_level_tiling.h | 3 + .../multi_level_tiling_wide_vector.cc | 120 ++++++++++++++++++ .../test_meta_schedule_schedule_rule_mlt.py | 53 ++++++++ 7 files changed, 253 insertions(+), 11 deletions(-) create mode 100644 src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h index 55704cf4a97d..2c9da1df9dae 100644 --- a/include/tvm/meta_schedule/schedule_rule.h +++ b/include/tvm/meta_schedule/schedule_rule.h @@ -187,6 +187,21 @@ class ScheduleRule : public runtime::ObjectRef { Optional> vector_load_lens, Optional> reuse_read, Optional> reuse_write, bool use_software_pipeline); + /*! + * \brief Extension of MultiLevelTiling for backends with wide vectors. + * The loop over the innermost spatial axis of the output buffer is always vectorized with the + * maximum vector length. + * \param structure The tiling structure. 'SSRSRS' is recommended. + * \param vector_length_in_bits The length of a vector register in bits. + * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit + * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse. + * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse. + * \return The schedule rule created + */ + TVM_DLL static ScheduleRule MultiLevelTilingWideVector( + String structure, Integer vector_length_in_bits, Optional max_innermost_factor, + Optional> reuse_read, Optional> reuse_write); + /*! * \brief Create a rule: add-rfactor to some blocks if needed * \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py index dd0119b0a7f8..a015d0eb1ab2 100644 --- a/python/tvm/meta_schedule/schedule_rule/__init__.py +++ b/python/tvm/meta_schedule/schedule_rule/__init__.py @@ -28,6 +28,7 @@ MultiLevelTilingWithIntrin, ReuseType, MultiLevelTilingTensorCore, + MultiLevelTilingWideVector, ) from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll from .random_compute_location import RandomComputeLocation diff --git a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py index 6703bc5716e9..e91382dd017a 100644 --- a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py +++ b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py @@ -187,3 +187,40 @@ def __init__( reuse_write.as_dict() if reuse_write is not None else None, use_software_pipeline, ) + + +@register_object("meta_schedule.MultiLevelTilingWideVector") +class MultiLevelTilingWideVector(ScheduleRule): + """Extension of MultiLevelTiling for backends with wide vectors. The loop over the innermost + spatial axis of the output buffer is always vectorized with the maximum vector length. + + Parameters + ---------- + structure : str + The tiling structure. 'SSRSRS' is recommended. + vector_length_in_bits: int + The length of a vector register in bits. 
+ max_innermost_factor : Optional[int] + The maximum size of the innermost factor. None means no limit + reuse_read : Optional[ReuseType] + Data reuse configuration for reading. None means no reuse. + reuse_write : Optional[ReuseType] + Data reuse configuration for writing. None means no reuse. + """ + + def __init__( + self, + structure: str, + vector_length_in_bits: int, + max_innermost_factor: Optional[int] = None, + reuse_read: Optional[ReuseType] = None, + reuse_write: Optional[ReuseType] = None, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleMultiLevelTilingWideVector, # type: ignore # pylint: disable=no-member + structure, + vector_length_in_bits, + max_innermost_factor, + reuse_read.as_dict() if reuse_read is not None else None, + reuse_write.as_dict() if reuse_write is not None else None, + ) diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index 1625a27b9aaf..f5715e812418 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -166,6 +166,17 @@ std::vector MultiLevelTilingNode::AddWriteReuse(State state) const { return results; } +Array MultiLevelTilingNode::SplitLoop(Schedule& sch, BlockRV block, LoopRV loop, + int n_tiles) const { + Array factors = sch->SamplePerfectTile( + /*loop=*/loop, + /*n=*/n_tiles, + /*max_innermost_factor=*/max_innermost_factor); + Array splits = sch->Split(/*loop=*/loop, + /*factors=*/{factors.begin(), factors.end()}); + return splits; +} + std::vector MultiLevelTilingNode::TileLoopNest(State state) const { Schedule& sch = state->sch; const BlockRV& block_rv = state->block_rv; @@ -179,6 +190,7 @@ std::vector MultiLevelTilingNode::TileLoopNest(State state) const { for (int i = 0, n = loops.size(); i < n; ++i) { LoopRV loop = loops[i]; const std::vector* idx = nullptr; + if (iter_types[i] == IterVarType::kDataPar) { idx = &s_indices_; if (spatial_loop_product != -1) { @@ -193,17 +205,18 @@ std::vector MultiLevelTilingNode::TileLoopNest(State state) const { } else { continue; } - // Do the split - int n_tiles = idx->size(); - Array factors = sch->SamplePerfectTile( - /*loop=*/loop, - /*n=*/n_tiles, - /*max_innermost_factor=*/max_innermost_factor); - Array splits = sch->Split(/*loop=*/loop, - /*factors=*/{factors.begin(), factors.end()}); - // Put every tile to its slot - for (int j = 0; j < n_tiles; ++j) { - tiles[idx->at(j)].push_back(splits[j]); + + const int n_tiles = idx->size(); + + if (n_tiles == 1) { + tiles[idx->at(0)].push_back(loop); + } else { + auto splits = SplitLoop(sch, block_rv, loop, n_tiles); + + // Put every tile to its slot + for (int j = 0; j < n_tiles; ++j) { + tiles[idx->at(j)].push_back(splits[j]); + } } } // Step 3. 
Reorder to organize the tiles diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h index 47da878c3be0..a1281f26ad33 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.h +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h @@ -161,6 +161,9 @@ class MultiLevelTilingNode : public ScheduleRuleNode { protected: virtual std::vector ApplySubRules(std::vector states); + virtual Array SplitLoop(tir::Schedule& sch, tir::BlockRV block, tir::LoopRV loop, + int n_tiles) const; + // Annotate a block to use cooperative fetching void AnnotateCooperativeFetching(tir::Schedule* sch, const tir::BlockRV& block) const; diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc new file mode 100644 index 000000000000..3be5b70accd7 --- /dev/null +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include "../../tir/schedule/analysis.h" +#include "../../tir/schedule/transform.h" +#include "../utils.h" +#include "multi_level_tiling.h" + +namespace tvm { +namespace meta_schedule { + +using tir::BlockRV; +using tir::LoopRV; +using tir::Schedule; + +/*! + * \brief Extension of MultiLevelTiling for backends with wide vectors. + * The loop over the innermost spatial axis of the output buffer is always vectorized with the + * maximum vector length. + */ +class MultiLevelTilingWideVectorNode : public MultiLevelTilingNode { + public: + size_t vector_length_in_bits; + + static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingWideVector"; + TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingWideVectorNode, MultiLevelTilingNode); + + protected: + Array SplitLoop(Schedule& sch, BlockRV block, LoopRV loop, int n_tiles) const; +}; + +Array MultiLevelTilingWideVectorNode::SplitLoop(Schedule& sch, BlockRV block_rv, + LoopRV loop_rv, int n_tiles) const { + const tir::ForNode* loop = TVM_SREF_TO_FOR(sch->GetSRef(loop_rv)); + const tir::StmtSRef block_sref = sch->GetSRef(block_rv); + const tir::BlockNode* block_node = block_sref->StmtAs(); + const tir::BlockRealize block_realize = tir::GetBlockRealize(sch->state(), block_sref); + ICHECK(block_node && block_node->writes.size() == 1); + + const auto out_dtype = block_node->writes[0]->buffer->dtype; + const int vec_len = vector_length_in_bits / out_dtype.bits(); + + // Determine if this loop is over the innermost axis of the output buffer. + // In the example below, we look for a loop whose loop var is bound to the axis co. 
+ + // for (i0, 0, 1) { + // for (i1, 0, 56) { + // for (i2, 0, 56) { + // for (i3, 0, 64) { + // for (i4, 0, 3) { + // for (i5, 0, 3) { + // for (i6, 0, 64) { + // block conv2d_nhwc(...) { + // ... + // bind(co, i3) + // ... + // writes([conv2d_nhwc[n, h, w, co]]) + // ... + // conv2d_nhwc[n, h, w, co] = ... + // } + const size_t innermost_axis = block_node->writes[0]->region.size() - 1; + const PrimExpr innermost_iter_value = block_realize->iter_values[innermost_axis]; + + if (!arith::Analyzer().CanProve(loop->loop_var == innermost_iter_value)) { + // If this is not the innermost spatial loop, split the loop in the normal way. + return MultiLevelTilingNode::SplitLoop(sch, block_rv, loop_rv, n_tiles); + } else { + // We split the innermost spatial loop in a way that always uses the maximum vector length. + const int64_t* extent_int = tir::GetLoopIntExtent(loop); + if (extent_int && *extent_int > vec_len) { + Array inner_splits = sch->Split(/*loop=*/loop_rv, + /*factors=*/{NullOpt, PrimExpr(vec_len)}); + Array outer_factors = sch->SamplePerfectTile( + /*loop=*/inner_splits[0], + /*n=*/n_tiles - 1, + /*max_innermost_factor=*/max_innermost_factor); + Array outer_splits = sch->Split( + /*loop=*/inner_splits[0], /*factors=*/{outer_factors.begin(), outer_factors.end()}); + outer_splits.push_back(inner_splits[1]); + return outer_splits; + } else { + Array factors(n_tiles - 1, PrimExpr(1)); + factors.push_back(loop->extent); + return sch->Split(/*loop=*/loop_rv, + /*factors=*/{factors.begin(), factors.end()}); + } + } +} + +ScheduleRule ScheduleRule::MultiLevelTilingWideVector( + String structure, Integer vector_length_in_bits, Optional max_innermost_factor, + Optional> reuse_read, Optional> reuse_write) { + auto node = MultiLevelTilingInitCommon( + structure, NullOpt, max_innermost_factor, NullOpt, reuse_read, reuse_write); + node->vector_length_in_bits = vector_length_in_bits->value; + return ScheduleRule(node); +} + +TVM_REGISTER_NODE_TYPE(MultiLevelTilingWideVectorNode); +TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingWideVector") + .set_body_typed(ScheduleRule::MultiLevelTilingWideVector); + +} // namespace meta_schedule +} // namespace tvm diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py index 939ccbe54fa6..d56aa337e033 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py @@ -521,6 +521,59 @@ def sum_with_trivial_block_iter( assert not sch.trace.simplified(remove_postproc=True).insts +def test_multi_level_tiling_hexagon(): + target_hexagon = tvm.target.hexagon("v69", num_cores=4) + target = tvm.target.Target(target_hexagon, host=target_hexagon) + + I = 64 + O = 64 + H = 56 + W = 56 + + data, kernel, out = te_workload.conv2d_nhwc( + 1, H, W, I, O, 3, 1, 1, 1, in_dtype="float16", out_dtype="float16" + ) + workload = te.create_prim_func([data, kernel, out]) + + ctx = _create_context( + workload, + target=target, + rule=schedule_rule.MultiLevelTilingWideVector( + structure="SRSRS", + vector_length_in_bits=1024, + max_innermost_factor=64, + reuse_read=None, + reuse_write=None, + ), + ) + spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) + assert len(spaces) == 1 + + expected = [ + """b0 = sch.get_block(name="conv2d_nhwc", func_name="main") +sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SRSRS") +l1, l2, l3, l4, l5, l6, l7 = sch.get_loops(block=b0) +v8, 
v9, v10 = sch.sample_perfect_tile(loop=l1, n=3, max_innermost_factor=64) +l11, l12, l13 = sch.split(loop=l1, factors=[v8, v9, v10], preserve_unit_iters=True) +v14, v15, v16 = sch.sample_perfect_tile(loop=l2, n=3, max_innermost_factor=64) +l17, l18, l19 = sch.split(loop=l2, factors=[v14, v15, v16], preserve_unit_iters=True) +v20, v21, v22 = sch.sample_perfect_tile(loop=l3, n=3, max_innermost_factor=64) +l23, l24, l25 = sch.split(loop=l3, factors=[v20, v21, v22], preserve_unit_iters=True) +l26, l27, l28 = sch.split(loop=l4, factors=[1, 1, 64], preserve_unit_iters=True) +v29, v30 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64) +l31, l32 = sch.split(loop=l5, factors=[v29, v30], preserve_unit_iters=True) +v33, v34 = sch.sample_perfect_tile(loop=l6, n=2, max_innermost_factor=64) +l35, l36 = sch.split(loop=l6, factors=[v33, v34], preserve_unit_iters=True) +v37, v38 = sch.sample_perfect_tile(loop=l7, n=2, max_innermost_factor=64) +l39, l40 = sch.split(loop=l7, factors=[v37, v38], preserve_unit_iters=True) +sch.reorder(l11, l17, l23, l26, l31, l35, l39, l12, l18, l24, l27, l32, l36, l40, l13, l19, l25, l28)""".split( + "\n" + ) + ] + + check_trace(spaces, expected) + + if __name__ == "__main__": test_cpu_matmul() test_cpu_matmul_relu() From 5f6885a2facc08d9a9186dde2162128f02d78243 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 20 Sep 2022 15:35:59 +0900 Subject: [PATCH 2/4] update test --- .../test_meta_schedule_schedule_rule_mlt.py | 91 ++++++++++++------- 1 file changed, 56 insertions(+), 35 deletions(-) diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py index d56aa337e033..02827ed1a134 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py @@ -16,7 +16,7 @@ # under the License. 
# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring from tvm import meta_schedule as ms -from tvm import te +from tvm import te, target from tvm.meta_schedule.testing import te_workload from tvm.meta_schedule.testing.schedule_rule import get_rules from tvm.meta_schedule.testing.space_generation import check_sketches @@ -135,6 +135,7 @@ def cpu_matmul_2( sch_rules=get_rules("llvm", ms.schedule_rule.MultiLevelTiling), task_name="test", ).generate_design_space() + check_sketches( mod, sketches=actual, @@ -522,56 +523,75 @@ def sum_with_trivial_block_iter( def test_multi_level_tiling_hexagon(): - target_hexagon = tvm.target.hexagon("v69", num_cores=4) - target = tvm.target.Target(target_hexagon, host=target_hexagon) + @T.prim_func + def cpu_conv2d_nhwc(inputs: T.Buffer[(1, 56, 56, 64), "float16"], weight: T.Buffer[(3, 3, 64, 64), "float16"], conv2d_nhwc: T.Buffer[(1, 56, 56, 64), "float16"]) -> None: + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + PadInput = T.alloc_buffer([1, 58, 58, 64], dtype="float16") + for i0, i1, i2, i3 in T.grid(1, 58, 58, 64): + with T.block("PadInput"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1]) + T.writes(PadInput[i0_1, i1_1, i2_1, i3_1]) + PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i1_1 and i1_1 < 57 and 1 <= i2_1 and i2_1 < 57, inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1], T.float16(0), dtype="float16") + for i0_0, i1_0, i2_0, i3_0, i4_0, i5_0, i6_0, i0_1_1, i1_1_1, i2_1_1, i3_1_1, i4_1, i5_1, i6_1, i0_2, i1_2, i2_2, i3_2 in T.grid(1, 1, 2, 1, 3, 3, 16, 1, 14, 2, 1, 1, 1, 4, 1, 4, 14, 64): + with T.block("conv2d_nhwc"): + n = T.axis.spatial(1, i0_1_1 + i0_2 + i0_0) + h = T.axis.spatial(56, i1_0 * 56 + i1_1_1 * 4 + i1_2) + w = T.axis.spatial(56, i2_0 * 28 + i2_1_1 * 14 + i2_2) + co = T.axis.spatial(64, i3_0 * 64 + i3_1_1 * 64 + i3_2) + rh = T.axis.reduce(3, i4_1 + i4_0) + rw = T.axis.reduce(3, i5_0 + i5_1) + rc = T.axis.reduce(64, i6_0 * 4 + i6_1) + T.reads(PadInput[n, h + rh, w + rw, co // 64 * 64 + rc], weight[rh, rw, rc, co]) + T.writes(conv2d_nhwc[n, h, w, co]) + T.block_attr({"meta_schedule.tiling_structure":"SRSRS"}) + with T.init(): + conv2d_nhwc[n, h, w, co] = T.float16(0) + conv2d_nhwc[n, h, w, co] = conv2d_nhwc[n, h, w, co] + PadInput[n, h + rh, w + rw, co // 64 * 64 + rc] * weight[rh, rw, rc, co] + + target_hexagon = target.hexagon("v69", num_cores=4) I = 64 O = 64 H = 56 W = 56 - data, kernel, out = te_workload.conv2d_nhwc( + mod = te.create_prim_func(te_workload.conv2d_nhwc( 1, H, W, I, O, 3, 1, 1, 1, in_dtype="float16", out_dtype="float16" - ) - workload = te.create_prim_func([data, kernel, out]) + )) - ctx = _create_context( - workload, - target=target, - rule=schedule_rule.MultiLevelTilingWideVector( + actual = ms.TuneContext( + mod=mod, + target=Target(target_hexagon, host=target_hexagon), + space_generator=ms.space_generator.PostOrderApply(), + sch_rules=[ms.schedule_rule.MultiLevelTilingWideVector( structure="SRSRS", vector_length_in_bits=1024, max_innermost_factor=64, reuse_read=None, reuse_write=None, - ), - ) - spaces = ctx.space_generator.generate_design_space(mod=ctx.mod) - assert len(spaces) == 1 + )], + task_name="test", + ).generate_design_space() - expected = [ - """b0 = sch.get_block(name="conv2d_nhwc", func_name="main") -sch.annotate(block_or_loop=b0, ann_key="meta_schedule.tiling_structure", ann_val="SRSRS") -l1, l2, l3, l4, l5, l6, l7 = 
sch.get_loops(block=b0) -v8, v9, v10 = sch.sample_perfect_tile(loop=l1, n=3, max_innermost_factor=64) -l11, l12, l13 = sch.split(loop=l1, factors=[v8, v9, v10], preserve_unit_iters=True) -v14, v15, v16 = sch.sample_perfect_tile(loop=l2, n=3, max_innermost_factor=64) -l17, l18, l19 = sch.split(loop=l2, factors=[v14, v15, v16], preserve_unit_iters=True) -v20, v21, v22 = sch.sample_perfect_tile(loop=l3, n=3, max_innermost_factor=64) -l23, l24, l25 = sch.split(loop=l3, factors=[v20, v21, v22], preserve_unit_iters=True) -l26, l27, l28 = sch.split(loop=l4, factors=[1, 1, 64], preserve_unit_iters=True) -v29, v30 = sch.sample_perfect_tile(loop=l5, n=2, max_innermost_factor=64) -l31, l32 = sch.split(loop=l5, factors=[v29, v30], preserve_unit_iters=True) -v33, v34 = sch.sample_perfect_tile(loop=l6, n=2, max_innermost_factor=64) -l35, l36 = sch.split(loop=l6, factors=[v33, v34], preserve_unit_iters=True) -v37, v38 = sch.sample_perfect_tile(loop=l7, n=2, max_innermost_factor=64) -l39, l40 = sch.split(loop=l7, factors=[v37, v38], preserve_unit_iters=True) -sch.reorder(l11, l17, l23, l26, l31, l35, l39, l12, l18, l24, l27, l32, l36, l40, l13, l19, l25, l28)""".split( - "\n" - ) + decision_0 = [ + ("SamplePerfectTile", [1, 1, 1]), + ("SamplePerfectTile", [1, 14, 4]), + ("SamplePerfectTile", [2, 2, 14]), + ("SamplePerfectTile", [3, 1]), + ("SamplePerfectTile", [3, 1]), + ("SamplePerfectTile", [16, 4]), ] - check_trace(spaces, expected) + check_sketches( + mod, + sketches=actual, + expected_mods=[cpu_conv2d_nhwc], + expected_decisions=[decision_0], + ) if __name__ == "__main__": @@ -580,3 +600,4 @@ def test_multi_level_tiling_hexagon(): test_cuda_matmul() test_cuda_matmul_relu() test_cuda_sum_with_trivial_block_iter() + test_multi_level_tiling_hexagon() From 55ca83a764027c4394df597e8ac074f2335fa1aa Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Tue, 20 Sep 2022 15:36:57 +0900 Subject: [PATCH 3/4] format --- .../schedule_rule/multi_level_tiling.cc | 2 +- .../test_meta_schedule_schedule_rule_mlt.py | 64 ++++++++++++++----- 2 files changed, 49 insertions(+), 17 deletions(-) diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index f5715e812418..7348e2002d9c 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -215,7 +215,7 @@ std::vector MultiLevelTilingNode::TileLoopNest(State state) const { // Put every tile to its slot for (int j = 0; j < n_tiles; ++j) { - tiles[idx->at(j)].push_back(splits[j]); + tiles[idx->at(j)].push_back(splits[j]); } } } diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py index 02827ed1a134..d9d078106333 100644 --- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py @@ -135,7 +135,6 @@ def cpu_matmul_2( sch_rules=get_rules("llvm", ms.schedule_rule.MultiLevelTiling), task_name="test", ).generate_design_space() - check_sketches( mod, sketches=actual, @@ -524,7 +523,11 @@ def sum_with_trivial_block_iter( def test_multi_level_tiling_hexagon(): @T.prim_func - def cpu_conv2d_nhwc(inputs: T.Buffer[(1, 56, 56, 64), "float16"], weight: T.Buffer[(3, 3, 64, 64), "float16"], conv2d_nhwc: T.Buffer[(1, 56, 56, 64), "float16"]) -> None: + def cpu_conv2d_nhwc( + inputs: T.Buffer[(1, 56, 56, 64), "float16"], + weight: T.Buffer[(3, 3, 64, 64), "float16"], + conv2d_nhwc: T.Buffer[(1, 56, 
56, 64), "float16"], + ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) # body @@ -535,8 +538,32 @@ def cpu_conv2d_nhwc(inputs: T.Buffer[(1, 56, 56, 64), "float16"], weight: T.Buff i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) T.reads(inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1]) T.writes(PadInput[i0_1, i1_1, i2_1, i3_1]) - PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(1 <= i1_1 and i1_1 < 57 and 1 <= i2_1 and i2_1 < 57, inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1], T.float16(0), dtype="float16") - for i0_0, i1_0, i2_0, i3_0, i4_0, i5_0, i6_0, i0_1_1, i1_1_1, i2_1_1, i3_1_1, i4_1, i5_1, i6_1, i0_2, i1_2, i2_2, i3_2 in T.grid(1, 1, 2, 1, 3, 3, 16, 1, 14, 2, 1, 1, 1, 4, 1, 4, 14, 64): + PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else( + 1 <= i1_1 and i1_1 < 57 and 1 <= i2_1 and i2_1 < 57, + inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1], + T.float16(0), + dtype="float16", + ) + for ( + i0_0, + i1_0, + i2_0, + i3_0, + i4_0, + i5_0, + i6_0, + i0_1_1, + i1_1_1, + i2_1_1, + i3_1_1, + i4_1, + i5_1, + i6_1, + i0_2, + i1_2, + i2_2, + i3_2, + ) in T.grid(1, 1, 2, 1, 3, 3, 16, 1, 14, 2, 1, 1, 1, 4, 1, 4, 14, 64): with T.block("conv2d_nhwc"): n = T.axis.spatial(1, i0_1_1 + i0_2 + i0_0) h = T.axis.spatial(56, i1_0 * 56 + i1_1_1 * 4 + i1_2) @@ -547,10 +574,13 @@ def cpu_conv2d_nhwc(inputs: T.Buffer[(1, 56, 56, 64), "float16"], weight: T.Buff rc = T.axis.reduce(64, i6_0 * 4 + i6_1) T.reads(PadInput[n, h + rh, w + rw, co // 64 * 64 + rc], weight[rh, rw, rc, co]) T.writes(conv2d_nhwc[n, h, w, co]) - T.block_attr({"meta_schedule.tiling_structure":"SRSRS"}) + T.block_attr({"meta_schedule.tiling_structure": "SRSRS"}) with T.init(): conv2d_nhwc[n, h, w, co] = T.float16(0) - conv2d_nhwc[n, h, w, co] = conv2d_nhwc[n, h, w, co] + PadInput[n, h + rh, w + rw, co // 64 * 64 + rc] * weight[rh, rw, rc, co] + conv2d_nhwc[n, h, w, co] = ( + conv2d_nhwc[n, h, w, co] + + PadInput[n, h + rh, w + rw, co // 64 * 64 + rc] * weight[rh, rw, rc, co] + ) target_hexagon = target.hexagon("v69", num_cores=4) @@ -559,21 +589,23 @@ def cpu_conv2d_nhwc(inputs: T.Buffer[(1, 56, 56, 64), "float16"], weight: T.Buff H = 56 W = 56 - mod = te.create_prim_func(te_workload.conv2d_nhwc( - 1, H, W, I, O, 3, 1, 1, 1, in_dtype="float16", out_dtype="float16" - )) + mod = te.create_prim_func( + te_workload.conv2d_nhwc(1, H, W, I, O, 3, 1, 1, 1, in_dtype="float16", out_dtype="float16") + ) actual = ms.TuneContext( mod=mod, target=Target(target_hexagon, host=target_hexagon), space_generator=ms.space_generator.PostOrderApply(), - sch_rules=[ms.schedule_rule.MultiLevelTilingWideVector( - structure="SRSRS", - vector_length_in_bits=1024, - max_innermost_factor=64, - reuse_read=None, - reuse_write=None, - )], + sch_rules=[ + ms.schedule_rule.MultiLevelTilingWideVector( + structure="SRSRS", + vector_length_in_bits=1024, + max_innermost_factor=64, + reuse_read=None, + reuse_write=None, + ) + ], task_name="test", ).generate_design_space() From a221bf8a89f7b8bba5bc5a829b00fa7429dcea32 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 21 Sep 2022 04:21:53 +0900 Subject: [PATCH 4/4] cpplint --- src/meta_schedule/schedule_rule/multi_level_tiling.cc | 2 +- src/meta_schedule/schedule_rule/multi_level_tiling.h | 4 ++-- .../schedule_rule/multi_level_tiling_wide_vector.cc | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index 7348e2002d9c..2ae6714f55d8 100644 --- 
a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -166,7 +166,7 @@ std::vector MultiLevelTilingNode::AddWriteReuse(State state) const { return results; } -Array MultiLevelTilingNode::SplitLoop(Schedule& sch, BlockRV block, LoopRV loop, +Array MultiLevelTilingNode::SplitLoop(const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles) const { Array factors = sch->SamplePerfectTile( /*loop=*/loop, diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h index a1281f26ad33..8f55e8e7e4e4 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.h +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h @@ -161,8 +161,8 @@ class MultiLevelTilingNode : public ScheduleRuleNode { protected: virtual std::vector ApplySubRules(std::vector states); - virtual Array SplitLoop(tir::Schedule& sch, tir::BlockRV block, tir::LoopRV loop, - int n_tiles) const; + virtual Array SplitLoop(const tir::Schedule& sch, tir::BlockRV block, + tir::LoopRV loop, int n_tiles) const; // Annotate a block to use cooperative fetching void AnnotateCooperativeFetching(tir::Schedule* sch, const tir::BlockRV& block) const; diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc index 3be5b70accd7..f5ec009a9b28 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc @@ -42,10 +42,10 @@ class MultiLevelTilingWideVectorNode : public MultiLevelTilingNode { TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingWideVectorNode, MultiLevelTilingNode); protected: - Array SplitLoop(Schedule& sch, BlockRV block, LoopRV loop, int n_tiles) const; + Array SplitLoop(const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles) const; }; -Array MultiLevelTilingWideVectorNode::SplitLoop(Schedule& sch, BlockRV block_rv, +Array MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch, BlockRV block_rv, LoopRV loop_rv, int n_tiles) const { const tir::ForNode* loop = TVM_SREF_TO_FOR(sch->GetSRef(loop_rv)); const tir::StmtSRef block_sref = sch->GetSRef(block_rv);
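
For reference, a minimal end-to-end sketch of how the new rule is driven from Python, condensed from the Hexagon test added in this series. The workload, target string, and rule arguments are taken verbatim from that test; the surrounding script framing (variable names, task_name) is an assumption for illustration, not part of the patch.

    import tvm
    from tvm import meta_schedule as ms
    from tvm import te
    from tvm.meta_schedule.testing import te_workload
    from tvm.target import Target

    # 1 x 56 x 56 x 64 NHWC conv2d with a 3x3 kernel, float16 in/out, as in the test above.
    mod = te.create_prim_func(
        te_workload.conv2d_nhwc(
            1, 56, 56, 64, 64, 3, 1, 1, 1, in_dtype="float16", out_dtype="float16"
        )
    )

    target_hexagon = tvm.target.hexagon("v69", num_cores=4)

    # With 1024-bit vector registers and a float16 output buffer, the rule fixes the
    # vector length at 1024 / 16 = 64 lanes, so the innermost spatial loop over `co`
    # (extent 64) is always split as [1, 1, 64]; only the remaining loops are sampled.
    spaces = ms.TuneContext(
        mod=mod,
        target=Target(target_hexagon, host=target_hexagon),
        space_generator=ms.space_generator.PostOrderApply(),
        sch_rules=[
            ms.schedule_rule.MultiLevelTilingWideVector(
                structure="SRSRS",
                vector_length_in_bits=1024,
                max_innermost_factor=64,
                reuse_read=None,
                reuse_write=None,
            )
        ],
        task_name="demo",
    ).generate_design_space()

    assert len(spaces) == 1

This matches the expected trace in the test: the co axis is split with sch.split(loop=l4, factors=[1, 1, 64]) and left for vectorization, while the other spatial and reduction loops go through the usual SamplePerfectTile path.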