diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index 55704cf4a97d..2c9da1df9dae 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -187,6 +187,21 @@ class ScheduleRule : public runtime::ObjectRef {
       Optional<Array<Integer>> vector_load_lens, Optional<Map<String, ObjectRef>> reuse_read,
       Optional<Map<String, ObjectRef>> reuse_write, bool use_software_pipeline);
+  /*!
+   * \brief Extension of MultiLevelTiling for backends with wide vectors.
+   * The loop over the innermost spatial axis of the output buffer is always vectorized with the
+   * maximum vector length.
+   * \param structure The tiling structure. 'SSRSRS' is recommended.
+   * \param vector_length_in_bits The length of a vector register in bits.
+   * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit.
+   * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
+   * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
+   * \return The schedule rule created.
+   */
+  TVM_DLL static ScheduleRule MultiLevelTilingWideVector(
+      String structure, Integer vector_length_in_bits, Optional<Integer> max_innermost_factor,
+      Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write);
+
   /*!
    * \brief Create a rule: add-rfactor to some blocks if needed
    * \param max_jobs_per_core The maximum number of jobs to be launched per CPU core. It sets the
diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py
index dd0119b0a7f8..a015d0eb1ab2 100644
--- a/python/tvm/meta_schedule/schedule_rule/__init__.py
+++ b/python/tvm/meta_schedule/schedule_rule/__init__.py
@@ -28,6 +28,7 @@
     MultiLevelTilingWithIntrin,
     ReuseType,
     MultiLevelTilingTensorCore,
+    MultiLevelTilingWideVector,
 )
 from .parallel_vectorize_unroll import ParallelizeVectorizeUnroll
 from .random_compute_location import RandomComputeLocation
diff --git a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
index 6703bc5716e9..e91382dd017a 100644
--- a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
+++ b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
@@ -187,3 +187,40 @@ def __init__(
             reuse_write.as_dict() if reuse_write is not None else None,
             use_software_pipeline,
         )
+
+
+@register_object("meta_schedule.MultiLevelTilingWideVector")
+class MultiLevelTilingWideVector(ScheduleRule):
+    """Extension of MultiLevelTiling for backends with wide vectors. The loop over the innermost
+    spatial axis of the output buffer is always vectorized with the maximum vector length.
+
+    Parameters
+    ----------
+    structure : str
+        The tiling structure. 'SSRSRS' is recommended.
+    vector_length_in_bits : int
+        The length of a vector register in bits.
+    max_innermost_factor : Optional[int]
+        The maximum size of the innermost factor. None means no limit.
+    reuse_read : Optional[ReuseType]
+        Data reuse configuration for reading. None means no reuse.
+    reuse_write : Optional[ReuseType]
+        Data reuse configuration for writing. None means no reuse.
+ """ + + def __init__( + self, + structure: str, + vector_length_in_bits: int, + max_innermost_factor: Optional[int] = None, + reuse_read: Optional[ReuseType] = None, + reuse_write: Optional[ReuseType] = None, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleMultiLevelTilingWideVector, # type: ignore # pylint: disable=no-member + structure, + vector_length_in_bits, + max_innermost_factor, + reuse_read.as_dict() if reuse_read is not None else None, + reuse_write.as_dict() if reuse_write is not None else None, + ) diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index 1625a27b9aaf..2ae6714f55d8 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -166,6 +166,17 @@ std::vector MultiLevelTilingNode::AddWriteReuse(State state) const { return results; } +Array MultiLevelTilingNode::SplitLoop(const Schedule& sch, BlockRV block, LoopRV loop, + int n_tiles) const { + Array factors = sch->SamplePerfectTile( + /*loop=*/loop, + /*n=*/n_tiles, + /*max_innermost_factor=*/max_innermost_factor); + Array splits = sch->Split(/*loop=*/loop, + /*factors=*/{factors.begin(), factors.end()}); + return splits; +} + std::vector MultiLevelTilingNode::TileLoopNest(State state) const { Schedule& sch = state->sch; const BlockRV& block_rv = state->block_rv; @@ -179,6 +190,7 @@ std::vector MultiLevelTilingNode::TileLoopNest(State state) const { for (int i = 0, n = loops.size(); i < n; ++i) { LoopRV loop = loops[i]; const std::vector* idx = nullptr; + if (iter_types[i] == IterVarType::kDataPar) { idx = &s_indices_; if (spatial_loop_product != -1) { @@ -193,17 +205,18 @@ std::vector MultiLevelTilingNode::TileLoopNest(State state) const { } else { continue; } - // Do the split - int n_tiles = idx->size(); - Array factors = sch->SamplePerfectTile( - /*loop=*/loop, - /*n=*/n_tiles, - /*max_innermost_factor=*/max_innermost_factor); - Array splits = sch->Split(/*loop=*/loop, - /*factors=*/{factors.begin(), factors.end()}); - // Put every tile to its slot - for (int j = 0; j < n_tiles; ++j) { - tiles[idx->at(j)].push_back(splits[j]); + + const int n_tiles = idx->size(); + + if (n_tiles == 1) { + tiles[idx->at(0)].push_back(loop); + } else { + auto splits = SplitLoop(sch, block_rv, loop, n_tiles); + + // Put every tile to its slot + for (int j = 0; j < n_tiles; ++j) { + tiles[idx->at(j)].push_back(splits[j]); + } } } // Step 3. 
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.h b/src/meta_schedule/schedule_rule/multi_level_tiling.h
index 47da878c3be0..8f55e8e7e4e4 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling.h
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling.h
@@ -161,6 +161,9 @@ class MultiLevelTilingNode : public ScheduleRuleNode {
  protected:
  virtual std::vector<State> ApplySubRules(std::vector<State> states);
 
+  virtual Array<tir::LoopRV> SplitLoop(const tir::Schedule& sch, tir::BlockRV block,
+                                       tir::LoopRV loop, int n_tiles) const;
+
   // Annotate a block to use cooperative fetching
   void AnnotateCooperativeFetching(tir::Schedule* sch, const tir::BlockRV& block) const;
 
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc
new file mode 100644
index 000000000000..f5ec009a9b28
--- /dev/null
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_wide_vector.cc
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "../../tir/schedule/analysis.h"
+#include "../../tir/schedule/transform.h"
+#include "../utils.h"
+#include "multi_level_tiling.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+using tir::BlockRV;
+using tir::LoopRV;
+using tir::Schedule;
+
+/*!
+ * \brief Extension of MultiLevelTiling for backends with wide vectors.
+ * The loop over the innermost spatial axis of the output buffer is always vectorized with the
+ * maximum vector length.
+ */
+class MultiLevelTilingWideVectorNode : public MultiLevelTilingNode {
+ public:
+  size_t vector_length_in_bits;
+
+  static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingWideVector";
+  TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingWideVectorNode, MultiLevelTilingNode);
+
+ protected:
+  Array<LoopRV> SplitLoop(const Schedule& sch, BlockRV block, LoopRV loop, int n_tiles) const;
+};
+
+Array<LoopRV> MultiLevelTilingWideVectorNode::SplitLoop(const Schedule& sch, BlockRV block_rv,
+                                                        LoopRV loop_rv, int n_tiles) const {
+  const tir::ForNode* loop = TVM_SREF_TO_FOR(sch->GetSRef(loop_rv));
+  const tir::StmtSRef block_sref = sch->GetSRef(block_rv);
+  const tir::BlockNode* block_node = block_sref->StmtAs<tir::BlockNode>();
+  const tir::BlockRealize block_realize = tir::GetBlockRealize(sch->state(), block_sref);
+  ICHECK(block_node && block_node->writes.size() == 1);
+
+  const auto out_dtype = block_node->writes[0]->buffer->dtype;
+  const int vec_len = vector_length_in_bits / out_dtype.bits();
+
+  // Determine if this loop is over the innermost axis of the output buffer.
+  // In the example below, we look for a loop whose loop var is bound to the axis co.
+
+  // for (i0, 0, 1) {
+  //   for (i1, 0, 56) {
+  //     for (i2, 0, 56) {
+  //       for (i3, 0, 64) {
+  //         for (i4, 0, 3) {
+  //           for (i5, 0, 3) {
+  //             for (i6, 0, 64) {
+  //               block conv2d_nhwc(...) {
+  //                 ...
+  //                 bind(co, i3)
+  //                 ...
+  //                 writes([conv2d_nhwc[n, h, w, co]])
+  //                 ...
+  //                 conv2d_nhwc[n, h, w, co] = ...
+  //               }
+  const size_t innermost_axis = block_node->writes[0]->region.size() - 1;
+  const PrimExpr innermost_iter_value = block_realize->iter_values[innermost_axis];
+
+  if (!arith::Analyzer().CanProve(loop->loop_var == innermost_iter_value)) {
+    // If this is not the innermost spatial loop, split the loop in the normal way.
+    return MultiLevelTilingNode::SplitLoop(sch, block_rv, loop_rv, n_tiles);
+  } else {
+    // We split the innermost spatial loop in a way that always uses the maximum vector length.
+    const int64_t* extent_int = tir::GetLoopIntExtent(loop);
+    if (extent_int && *extent_int > vec_len) {
+      Array<LoopRV> inner_splits = sch->Split(/*loop=*/loop_rv,
+                                              /*factors=*/{NullOpt, PrimExpr(vec_len)});
+      Array<tir::ExprRV> outer_factors = sch->SamplePerfectTile(
+          /*loop=*/inner_splits[0],
+          /*n=*/n_tiles - 1,
+          /*max_innermost_factor=*/max_innermost_factor);
+      Array<LoopRV> outer_splits = sch->Split(
+          /*loop=*/inner_splits[0], /*factors=*/{outer_factors.begin(), outer_factors.end()});
+      outer_splits.push_back(inner_splits[1]);
+      return outer_splits;
+    } else {
+      Array<PrimExpr> factors(n_tiles - 1, PrimExpr(1));
+      factors.push_back(loop->extent);
+      return sch->Split(/*loop=*/loop_rv,
+                        /*factors=*/{factors.begin(), factors.end()});
+    }
+  }
+}
+
+ScheduleRule ScheduleRule::MultiLevelTilingWideVector(
+    String structure, Integer vector_length_in_bits, Optional<Integer> max_innermost_factor,
+    Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write) {
+  auto node = MultiLevelTilingInitCommon<MultiLevelTilingWideVectorNode>(
+      structure, NullOpt, max_innermost_factor, NullOpt, reuse_read, reuse_write);
+  node->vector_length_in_bits = vector_length_in_bits->value;
+  return ScheduleRule(node);
+}
+
+TVM_REGISTER_NODE_TYPE(MultiLevelTilingWideVectorNode);
+TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingWideVector")
+    .set_body_typed(ScheduleRule::MultiLevelTilingWideVector);
+
+}  // namespace meta_schedule
+}  // namespace tvm
diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
index 939ccbe54fa6..d9d078106333 100644
--- a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
+++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt.py
@@ -16,7 +16,7 @@
 # under the License.
 # pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring
 from tvm import meta_schedule as ms
-from tvm import te
+from tvm import te, target
 from tvm.meta_schedule.testing import te_workload
 from tvm.meta_schedule.testing.schedule_rule import get_rules
 from tvm.meta_schedule.testing.space_generation import check_sketches
@@ -521,9 +521,115 @@ def sum_with_trivial_block_iter(
     assert not sch.trace.simplified(remove_postproc=True).insts
 
 
+def test_multi_level_tiling_hexagon():
+    @T.prim_func
+    def cpu_conv2d_nhwc(
+        inputs: T.Buffer[(1, 56, 56, 64), "float16"],
+        weight: T.Buffer[(3, 3, 64, 64), "float16"],
+        conv2d_nhwc: T.Buffer[(1, 56, 56, 64), "float16"],
+    ) -> None:
+        # function attr dict
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        # body
+        # with T.block("root")
+        PadInput = T.alloc_buffer([1, 58, 58, 64], dtype="float16")
+        for i0, i1, i2, i3 in T.grid(1, 58, 58, 64):
+            with T.block("PadInput"):
+                i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3])
+                T.reads(inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1])
+                T.writes(PadInput[i0_1, i1_1, i2_1, i3_1])
+                PadInput[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(
+                    1 <= i1_1 and i1_1 < 57 and 1 <= i2_1 and i2_1 < 57,
+                    inputs[i0_1, i1_1 - 1, i2_1 - 1, i3_1],
+                    T.float16(0),
+                    dtype="float16",
+                )
+        for (
+            i0_0,
+            i1_0,
+            i2_0,
+            i3_0,
+            i4_0,
+            i5_0,
+            i6_0,
+            i0_1_1,
+            i1_1_1,
+            i2_1_1,
+            i3_1_1,
+            i4_1,
+            i5_1,
+            i6_1,
+            i0_2,
+            i1_2,
+            i2_2,
+            i3_2,
+        ) in T.grid(1, 1, 2, 1, 3, 3, 16, 1, 14, 2, 1, 1, 1, 4, 1, 4, 14, 64):
+            with T.block("conv2d_nhwc"):
+                n = T.axis.spatial(1, i0_1_1 + i0_2 + i0_0)
+                h = T.axis.spatial(56, i1_0 * 56 + i1_1_1 * 4 + i1_2)
+                w = T.axis.spatial(56, i2_0 * 28 + i2_1_1 * 14 + i2_2)
+                co = T.axis.spatial(64, i3_0 * 64 + i3_1_1 * 64 + i3_2)
+                rh = T.axis.reduce(3, i4_1 + i4_0)
+                rw = T.axis.reduce(3, i5_0 + i5_1)
+                rc = T.axis.reduce(64, i6_0 * 4 + i6_1)
+                T.reads(PadInput[n, h + rh, w + rw, co // 64 * 64 + rc], weight[rh, rw, rc, co])
+                T.writes(conv2d_nhwc[n, h, w, co])
+                T.block_attr({"meta_schedule.tiling_structure": "SRSRS"})
+                with T.init():
+                    conv2d_nhwc[n, h, w, co] = T.float16(0)
+                conv2d_nhwc[n, h, w, co] = (
+                    conv2d_nhwc[n, h, w, co]
+                    + PadInput[n, h + rh, w + rw, co // 64 * 64 + rc] * weight[rh, rw, rc, co]
+                )
+
+    target_hexagon = target.hexagon("v69", num_cores=4)
+
+    I = 64
+    O = 64
+    H = 56
+    W = 56
+
+    mod = te.create_prim_func(
+        te_workload.conv2d_nhwc(1, H, W, I, O, 3, 1, 1, 1, in_dtype="float16", out_dtype="float16")
+    )
+
+    actual = ms.TuneContext(
+        mod=mod,
+        target=Target(target_hexagon, host=target_hexagon),
+        space_generator=ms.space_generator.PostOrderApply(),
+        sch_rules=[
+            ms.schedule_rule.MultiLevelTilingWideVector(
+                structure="SRSRS",
+                vector_length_in_bits=1024,
+                max_innermost_factor=64,
+                reuse_read=None,
+                reuse_write=None,
+            )
+        ],
+        task_name="test",
+    ).generate_design_space()
+
+    decision_0 = [
+        ("SamplePerfectTile", [1, 1, 1]),
+        ("SamplePerfectTile", [1, 14, 4]),
+        ("SamplePerfectTile", [2, 2, 14]),
+        ("SamplePerfectTile", [3, 1]),
+        ("SamplePerfectTile", [3, 1]),
+        ("SamplePerfectTile", [16, 4]),
+    ]
+
+    check_sketches(
+        mod,
+        sketches=actual,
+        expected_mods=[cpu_conv2d_nhwc],
+        expected_decisions=[decision_0],
+    )
+
+
 if __name__ == "__main__":
     test_cpu_matmul()
     test_cpu_matmul_relu()
     test_cuda_matmul()
     test_cuda_matmul_relu()
     test_cuda_sum_with_trivial_block_iter()
+    test_multi_level_tiling_hexagon()
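
Editor's note (not part of the patch): the sketch below is a minimal Python mirror of the factor selection that MultiLevelTilingWideVectorNode::SplitLoop applies to the loop bound to the innermost spatial axis of the output buffer, added only to illustrate the splitting rule. The function names wide_vector_split_factors and _stand_in_perfect_tile are hypothetical, and the deterministic stand-in replaces the random SamplePerfectTile of the real rule.

# Illustration only, assuming the extent is a multiple of the vector length when it
# exceeds it (as in the conv2d test above).

def _stand_in_perfect_tile(extent: int, n: int) -> list:
    # Deterministic stand-in for SamplePerfectTile: put the whole extent outermost.
    return [extent] + [1] * (n - 1)


def wide_vector_split_factors(
    extent: int, n_tiles: int, vector_length_in_bits: int, out_dtype_bits: int
) -> list:
    vec_len = vector_length_in_bits // out_dtype_bits  # e.g. 1024 // 16 = 64 for float16
    if extent <= vec_len:
        # Short loop: pad with unit tiles so the whole extent is vectorized.
        return [1] * (n_tiles - 1) + [extent]
    # Long loop: fix the innermost factor to the full vector length and tile the
    # remaining outer iterations into n_tiles - 1 factors (sampled in the real rule).
    outer_extent = extent // vec_len
    return _stand_in_perfect_tile(outer_extent, n_tiles - 1) + [vec_len]


# For the test above (float16 output, 1024-bit vectors, co extent 64, 3 spatial tiles),
# the co loop gets the fixed factors [1, 1, 64], which is why decision_0 contains no
# SamplePerfectTile entry for it.
assert wide_vector_split_factors(64, 3, 1024, 16) == [1, 1, 64]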