From 84dc7b9beb6938595bae56c4fc9eec473f4b6ffd Mon Sep 17 00:00:00 2001
From: Noah Verke
Date: Fri, 6 Jan 2023 17:04:12 -0800
Subject: [PATCH 1/4] [MetaSchedule][Hexagon] Add MultiLevelTilingHexagon to schedule async pipelines that utilize DMA

---
 include/tvm/meta_schedule/schedule_rule.h     |   25 +
 include/tvm/tir/data_type_rewriter.h          |    1 +
 .../meta_schedule/schedule_rule/__init__.py   |    1 +
 .../schedule_rule/multi_level_tiling.py       |   54 +
 python/tvm/tir/tensor_intrin/hexagon.py       |   59 +-
 .../multi_level_tiling_hexagon.cc             |  145 ++
 .../multi_level_tiling_with_intrin.cc         |   75 +-
 .../multi_level_tiling_with_intrin.h          |   68 +
 src/tir/ir/data_type_rewriter.cc              |   23 +-
 src/tir/ir/stmt.cc                            |    4 +
 .../contrib/test_hexagon/test_conv2d_async.py |  184 +++
 ...meta_schedule_schedule_rule_mlt_hexagon.py | 1177 +++++++++++++++++
 12 files changed, 1746 insertions(+), 70 deletions(-)
 create mode 100644 src/meta_schedule/schedule_rule/multi_level_tiling_hexagon.cc
 create mode 100644 src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.h
 create mode 100644 tests/python/contrib/test_hexagon/test_conv2d_async.py
 create mode 100644 tests/python/unittest/test_meta_schedule_schedule_rule_mlt_hexagon.py

diff --git a/include/tvm/meta_schedule/schedule_rule.h b/include/tvm/meta_schedule/schedule_rule.h
index 7995d1fceeb6..b758dcb48223 100644
--- a/include/tvm/meta_schedule/schedule_rule.h
+++ b/include/tvm/meta_schedule/schedule_rule.h
@@ -210,6 +210,31 @@ class ScheduleRule : public runtime::ObjectRef {
       Optional<Array<Integer>> vector_load_lens, Optional<Map<String, ObjectRef>> reuse_read,
       Optional<Map<String, ObjectRef>> reuse_write, bool use_software_pipeline);
+  /*!
+   * \brief Extension of MultiLevelTiling for auto-tensorization with multiple groups of candidate
+   * Hexagon tensor intrinsics
+   * \param intrin_groups A list of groups of Hexagon tensor intrinsics. Each map must contain the
+   * key "compute", which represents the tensor intrin for computation. The values of the map are
+   * names of tensor intrinsics, which must be registered via
+   * TensorIntrin.register(...) beforehand
+   * \param structure The tiling structure. Recommended:
+   * - 'SRSRS' on Hexagon
+   * \param tile_binds For each level of tiles, which thread axis it is bound to. These are not
+   * supported on Hexagon.
+   * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit
+   * \param vector_load_lens The length of vector lane in vectorized cooperative fetching.
+   * NullOpt means disable vectorization
+   * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse.
+   * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse.
+   * \param use_software_pipeline Whether to use the software pipeline.
+   * \return The schedule rule created
+   */
+  TVM_DLL static ScheduleRule MultiLevelTilingHexagon(
+      Array<Map<String, String>> intrin_groups, String structure,
+      Optional<Array<String>> tile_binds, Optional<Integer> max_innermost_factor,
+      Optional<Array<Integer>> vector_load_lens, Optional<Map<String, ObjectRef>> reuse_read,
+      Optional<Map<String, ObjectRef>> reuse_write, bool use_software_pipeline);
+
   /*!
    * \brief Extension of MultiLevelTiling for backends with wide vectors.
  * The loop over the innermost spatial axis of the output buffer is always vectorized with the
diff --git a/include/tvm/tir/data_type_rewriter.h b/include/tvm/tir/data_type_rewriter.h
index 5f72f75ede41..9cd6eae02028 100644
--- a/include/tvm/tir/data_type_rewriter.h
+++ b/include/tvm/tir/data_type_rewriter.h
@@ -108,6 +108,7 @@ class IndexDataTypeRewriter : public DataTypeLegalizer {
   Stmt VisitStmt_(const IfThenElseNode* op) override;
   Stmt VisitStmt_(const DeclBufferNode* op) override;
   Stmt VisitStmt_(const AllocateNode* op) override;
+  Stmt VisitStmt_(const AllocateConstNode* op) override;
   PrimExpr VisitExpr_(const EQNode* op) override;
   PrimExpr VisitExpr_(const NENode* op) override;
   PrimExpr VisitExpr_(const LTNode* op) override;
diff --git a/python/tvm/meta_schedule/schedule_rule/__init__.py b/python/tvm/meta_schedule/schedule_rule/__init__.py
index d330fc713991..92ade43e7438 100644
--- a/python/tvm/meta_schedule/schedule_rule/__init__.py
+++ b/python/tvm/meta_schedule/schedule_rule/__init__.py
@@ -27,6 +27,7 @@
 from .multi_level_tiling import (
     MultiLevelTiling,
     MultiLevelTilingTensorCore,
+    MultiLevelTilingHexagon,
     MultiLevelTilingWideVector,
     MultiLevelTilingWithIntrin,
     ReuseType,
diff --git a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
index 19651a2ce18e..7274e05050d2 100644
--- a/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
+++ b/python/tvm/meta_schedule/schedule_rule/multi_level_tiling.py
@@ -197,6 +197,60 @@ def __init__(
         )
 
 
+@register_object("meta_schedule.MultiLevelTilingHexagon")
+class MultiLevelTilingHexagon(ScheduleRule):
+    """Extension of MultiLevelTiling for auto-tensorizing with multiple groups of candidate
+    Hexagon intrinsics.
+
+    Parameters
+    ----------
+    intrin_groups : List[Mapping[str, str]]
+        A list of groups of Hexagon tensor intrinsics. Each map must contain the key
+        "compute", which represents the tensor intrin for computation. The values of the map
+        are names of tensor intrinsics, which must be registered via
+        TensorIntrin.register(...) beforehand.
+    structure : str
+        The tiling structure. Recommended:
+        - 'SRSRS' on Hexagon
+    tile_binds : Optional[List[str]]
+        For each level of tiles, which thread axis it is bound to. Not supported on Hexagon.
+    max_innermost_factor : Optional[int]
+        The maximum size of the innermost factor. None means no limit.
+    vector_load_lens : Optional[List[int]]
+        The length of vector lane in vectorized cooperative fetching.
+        None means disable vectorization.
+    reuse_read : Optional[ReuseType]
+        Data reuse configuration for reading. None means no reuse.
+    reuse_write : Optional[ReuseType]
+        Data reuse configuration for writing. None means no reuse.
+    use_software_pipeline : bool
+        Whether to use the software pipeline.
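+
+    Examples
+    --------
+    A minimal usage sketch, mirroring the unit tests added in this patch; the vrmpy
+    intrinsic name comes from ``tvm.tir.tensor_intrin.hexagon`` and any other registered
+    compute intrinsic can be substituted:
+
+    .. code-block:: python
+
+        from tvm import meta_schedule as ms
+        from tvm.tir.tensor_intrin.hexagon import VRMPY_u8u8i32_VTCM_INTRIN
+
+        rule = ms.schedule_rule.MultiLevelTilingHexagon(
+            intrin_groups=[{"compute": VRMPY_u8u8i32_VTCM_INTRIN}],
+            structure="SRSRS",
+            max_innermost_factor=64,
+            reuse_read=ms.schedule_rule.ReuseType(
+                req="must", levels=[2], scope="global.vtcm"
+            ),
+            use_software_pipeline=True,
+        )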
+ """ + + def __init__( + self, + intrin_groups: List[Mapping[str, str]], + structure: str, + tile_binds: Optional[List[str]] = None, + max_innermost_factor: Optional[int] = None, + vector_load_lens: Optional[List[int]] = None, + reuse_read: Optional[ReuseType] = None, + reuse_write: Optional[ReuseType] = None, + use_software_pipeline: bool = False, + ) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ScheduleRuleMultiLevelTilingHexagon, # type: ignore # pylint: disable=no-member + intrin_groups, + structure, + tile_binds, + max_innermost_factor, + vector_load_lens, + reuse_read.as_dict() if reuse_read is not None else None, + reuse_write.as_dict() if reuse_write is not None else None, + use_software_pipeline, + ) + + @register_object("meta_schedule.MultiLevelTilingWideVector") class MultiLevelTilingWideVector(ScheduleRule): """Extension of MultiLevelTiling for backends with wide vectors. The loop over the innermost diff --git a/python/tvm/tir/tensor_intrin/hexagon.py b/python/tvm/tir/tensor_intrin/hexagon.py index 7a348f3f1a45..c67510c1e2ff 100644 --- a/python/tvm/tir/tensor_intrin/hexagon.py +++ b/python/tvm/tir/tensor_intrin/hexagon.py @@ -16,6 +16,7 @@ # under the License. # pylint: disable=invalid-name,missing-function-docstring """Intrinsics for Hexagon tensorization.""" + from tvm.script import tir as T from .. import TensorIntrin @@ -68,12 +69,12 @@ def sync_dma_load_impl(a: T.handle, c: T.handle) -> None: return sync_dma_load_desc, sync_dma_load_impl -def generate_dot_product_32x4_u8u8i32(mem_scope="global"): +def generate_dot_product_32x4_u8u8i32(read_mem_scope="global", write_mem_scope="global"): @T.prim_func def dot_product_32x4_u8u8i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope) - B = T.match_buffer(b, (32, 4), "uint8", offset_factor=1, scope=mem_scope) - C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope) + A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=read_mem_scope) + B = T.match_buffer(b, (32, 4), "uint8", offset_factor=1, scope=read_mem_scope) + C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=write_mem_scope) with T.block("root"): T.reads(C[0:32], A[0:4], B[0:32, 0:4]) T.writes(C[0:32]) @@ -85,9 +86,9 @@ def dot_product_32x4_u8u8i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None @T.prim_func def dot_product_32x4_u8u8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope) - B = T.match_buffer(b, (32, 4), "uint8", offset_factor=1, scope=mem_scope) - C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope) + A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=read_mem_scope) + B = T.match_buffer(b, (32, 4), "uint8", offset_factor=1, scope=read_mem_scope) + C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=write_mem_scope) with T.block("root"): T.reads(C[0:32], A[0:4], B[0:32, 0:4]) T.writes(C[0:32]) @@ -110,12 +111,12 @@ def dot_product_32x4_u8u8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> Non return dot_product_32x4_u8u8i32_desc, dot_product_32x4_u8u8i32_vrmpy -def generate_dot_product_32x4_u8i8i32(mem_scope="global"): +def generate_dot_product_32x4_u8i8i32(read_mem_scope="global", write_mem_scope="global"): @T.prim_func def dot_product_32x4_u8i8i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope) - B = T.match_buffer(b, (32, 4), "int8", 
offset_factor=1, scope=mem_scope) - C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope) + A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=read_mem_scope) + B = T.match_buffer(b, (32, 4), "int8", offset_factor=1, scope=read_mem_scope) + C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=write_mem_scope) with T.block("root"): T.reads(C[0:32], A[0:4], B[0:32, 0:4]) T.writes(C[0:32]) @@ -127,9 +128,9 @@ def dot_product_32x4_u8i8i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None @T.prim_func def dot_product_32x4_u8i8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope) - B = T.match_buffer(b, (32, 4), "int8", offset_factor=1, scope=mem_scope) - C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope) + A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=read_mem_scope) + B = T.match_buffer(b, (32, 4), "int8", offset_factor=1, scope=read_mem_scope) + C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=write_mem_scope) with T.block("root"): T.reads(C[0:32], A[0:4], B[0:32, 0:4]) T.writes(C[0:32]) @@ -152,12 +153,12 @@ def dot_product_32x4_u8i8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> Non return dot_product_32x4_u8i8i32_desc, dot_product_32x4_u8i8i32_vrmpy -def generate_dot_product_32x2_i16i16i32(mem_scope="global"): +def generate_dot_product_32x2_i16i16i32(read_mem_scope="global", write_mem_scope="global"): @T.prim_func def dot_product_32x2_i16i16i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (2,), "int16", offset_factor=1, scope=mem_scope) - B = T.match_buffer(b, (32, 2), "int16", offset_factor=1, scope=mem_scope) - C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope) + A = T.match_buffer(a, (2,), "int16", offset_factor=1, scope=read_mem_scope) + B = T.match_buffer(b, (32, 2), "int16", offset_factor=1, scope=read_mem_scope) + C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=write_mem_scope) with T.block("root"): T.reads(C[0:32], A[0:2], B[0:32, 0:2]) T.writes(C[0:32]) @@ -169,9 +170,9 @@ def dot_product_32x2_i16i16i32_desc(a: T.handle, b: T.handle, c: T.handle) -> No @T.prim_func def dot_product_32x2_i16i16i32_vdmpy(a: T.handle, b: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (2,), "int16", offset_factor=1, scope=mem_scope) - B = T.match_buffer(b, (32, 2), "int16", offset_factor=1, scope=mem_scope) - C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope) + A = T.match_buffer(a, (2,), "int16", offset_factor=1, scope=read_mem_scope) + B = T.match_buffer(b, (32, 2), "int16", offset_factor=1, scope=read_mem_scope) + C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=write_mem_scope) with T.block("root"): T.reads(C[0:32], A[0:2], B[0:32, 0:2]) T.writes(C[0:32]) @@ -207,10 +208,22 @@ def dot_product_32x2_i16i16i32_vdmpy(a: T.handle, b: T.handle, c: T.handle) -> N TensorIntrin.register(VDMPY_i16i16i32_INTRIN, *generate_dot_product_32x2_i16i16i32()) VRMPY_u8u8i32_VTCM_INTRIN = "dot_32x4_u8u8i32_vtcm_vrmpy" -TensorIntrin.register(VRMPY_u8u8i32_VTCM_INTRIN, *generate_dot_product_32x4_u8u8i32("global.vtcm")) +TensorIntrin.register( + VRMPY_u8u8i32_VTCM_INTRIN, + *generate_dot_product_32x4_u8u8i32("global.vtcm", "global.vtcm"), +) + +VRMPY_u8u8i32_VTCM_READS_INTRIN = "dot_32x4_u8u8i32_vtcm_reads_vrmpy" +TensorIntrin.register( + VRMPY_u8u8i32_VTCM_READS_INTRIN, + *generate_dot_product_32x4_u8u8i32("global.vtcm", "global"), +) 
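+
+# Note: the "_READS" variant above stages only the operands in VTCM
+# (read_mem_scope="global.vtcm") while the accumulator stays in global memory
+# (write_mem_scope="global"). Schedules that DMA their inputs into VTCM but
+# write results directly to global memory tensorize against this variant.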
 VRMPY_u8i8i32_VTCM_INTRIN = "dot_32x4_u8i8i32_vtcm_vrmpy"
-TensorIntrin.register(VRMPY_u8i8i32_VTCM_INTRIN, *generate_dot_product_32x4_u8i8i32("global.vtcm"))
+TensorIntrin.register(
+    VRMPY_u8i8i32_VTCM_INTRIN,
+    *generate_dot_product_32x4_u8i8i32("global.vtcm", "global.vtcm"),
+)
 
 DMA_READ_128_u8 = "dma_read_128_u8"
 TensorIntrin.register(DMA_READ_128_u8, *generate_dma_load_intrin(128, "uint8"))
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_hexagon.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_hexagon.cc
new file mode 100644
index 000000000000..c8636dc5b429
--- /dev/null
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_hexagon.cc
@@ -0,0 +1,145 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include "../../tir/schedule/analysis.h"
+#include "../../tir/schedule/transform.h"
+#include "../utils.h"
+#include "multi_level_tiling_with_intrin.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+using tir::BlockRV;
+using tir::LoopRV;
+using tir::Schedule;
+
+class MultiLevelTilingHexagonNode : public MultiLevelTilingWithIntrinNode {
+ private:
+  // Subrule: Add software pipeline
+  inline std::vector<State> AddSoftwarePipeline(State state) const;
+
+  // Override ApplySubRules to apply tensorization-specific sub-rules
+  std::vector<State> ApplySubRules(std::vector<State> states) final;
+
+  // Inherited from ScheduleRuleNode
+  ScheduleRule Clone() const override {
+    ObjectPtr<MultiLevelTilingHexagonNode> n = make_object<MultiLevelTilingHexagonNode>(*this);
+    return ScheduleRule(n);
+  }
+
+ public:
+  /*! \brief Whether to use software pipeline */
+  bool use_software_pipeline = false;
+  static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingHexagon";
+  TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingHexagonNode, MultiLevelTilingNode);
+};
+
+std::vector<State> MultiLevelTilingHexagonNode::ApplySubRules(std::vector<State> states) {
+  states = MultiLevelTilingWithIntrinNode::ApplySubRules(states);
+  states = SubRule(std::move(states), [&](State state) { return AddSoftwarePipeline(state); });
+  return states;
+}
+
+std::vector<State> MultiLevelTilingHexagonNode::AddSoftwarePipeline(State state) const {
+  if (!use_software_pipeline) {
+    return {state};
+  }
+  // The current config is not suitable for software pipelining.
+  if (r_indices_.size() < 2) {
+    return {state};
+  }
+
+  Schedule& sch = state->sch;
+  // Check reduction length after blockize.
+  int64_t reduction_length = 1;
+  for (int r_index : r_indices_) {
+    const Array<LoopRV>& tiles = state->tiles[r_index];
+    for (const LoopRV& tile : tiles) {
+      const auto* extent = sch->Get(tile)->extent.as<IntImmNode>();
+      ICHECK(extent != nullptr) << "Dynamic extent is not supported.";
+      reduction_length *= extent->value;
+    }
+  }
+  if (reduction_length <= 1) {
+    return {state};
+  }
+
+  // Return if there are fewer than 1 or more than 2 cache reads.
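+  // Pipeline layout used below: with two cache reads, both DMA copy blocks are
+  // placed in stage 0 and the compute block in stage 1 (order 0, 1, 2), so the
+  // copies for iteration i+1 overlap the compute of iteration i. Stage 0 is
+  // marked async so that the copies lower to asynchronous DMA. With a single
+  // cache read, the same two-stage split is used with one copy block.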
+  size_t cache_read_count = state->read_reuse.size();
+  if (cache_read_count > 2 || cache_read_count == 0) {
+    return {state};
+  }
+
+  // Add annotations for software pipelining at the loop right above the cache read stages.
+  Array<Integer> software_pipeline_stage;
+  Array<Integer> software_pipeline_order;
+  Array<Integer> software_pipeline_async_stages;
+  if (cache_read_count == 2) {
+    software_pipeline_stage = Array<Integer>{0, 0, 1};
+    software_pipeline_order = Array<Integer>{0, 1, 2};
+    software_pipeline_async_stages = Array<Integer>{0};
+  } else {
+    software_pipeline_stage = Array<Integer>{0, 1};
+    software_pipeline_order = Array<Integer>{0, 1};
+    software_pipeline_async_stages = Array<Integer>{0};
+  }
+
+  tir::BlockRV cache_read_block = state->read_reuse.begin()->second;
+  Array<LoopRV> cache_read_loops = sch->GetLoops(cache_read_block);
+  Array<LoopRV> reduction_loops;
+  for (size_t i = 0; i < cache_read_loops.size() - 1; ++i) {
+    if (tir::GetLoopIterType(sch->GetSRef(cache_read_loops[i])) != tir::IterVarType::kDataPar) {
+      reduction_loops.push_back(cache_read_loops[i]);
+    } else if (reduction_loops.size() > 0 &&
+               sch->Get(cache_read_loops[i])->extent.as<IntImmNode>()->value == 1) {
+      reduction_loops.push_back(cache_read_loops[i]);
+    }
+  }
+  auto fused = sch->Fuse(reduction_loops);
+
+  sch->Annotate(fused, tir::attr::software_pipeline_stage, software_pipeline_stage);
+  sch->Annotate(fused, tir::attr::software_pipeline_order, software_pipeline_order);
+  sch->Annotate(fused, tir::attr::software_pipeline_async_stages, software_pipeline_async_stages);
+
+  // TODO(nverke): Add support for nested async pipelines.
+  // TODO(nverke): Add support for async cache writes.
+
+  return {state};
+}
+
+ScheduleRule ScheduleRule::MultiLevelTilingHexagon(
+    Array<Map<String, String>> intrin_groups, String structure, Optional<Array<String>> tile_binds,
+    Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens,
+    Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write,
+    bool use_software_pipeline) {
+  CHECK(!tile_binds.defined()) << "Tile binds cannot be used on Hexagon.";
+  auto node = MultiLevelTilingInitCommon<MultiLevelTilingHexagonNode>(
+      structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write);
+
+  node->intrin_name = intrin_groups[0]["compute"];
+  node->use_software_pipeline = use_software_pipeline;
+  return ScheduleRule(node);
+}
+
+TVM_REGISTER_NODE_TYPE(MultiLevelTilingHexagonNode);
+TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingHexagon")
+    .set_body_typed(ScheduleRule::MultiLevelTilingHexagon);
+
+}  // namespace meta_schedule
+}  // namespace tvm
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc
index 428a1206a4ca..3eae6d402200 100644
--- a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.cc
@@ -17,10 +17,14 @@
  * under the License.
  */
 
+#include "multi_level_tiling_with_intrin.h"
+
+#include
+#include
+#include
+
 #include "../../tir/schedule/analysis.h"
-#include "../../tir/schedule/transform.h"
 #include "../utils.h"
-#include "multi_level_tiling.h"
 
 namespace tvm {
 namespace meta_schedule {
@@ -41,55 +45,35 @@ Optional<tir::BlockRV> TileForIntrin(tir::Schedule sch, tir::BlockRV block,
   return outer_block;
 }
 
-/*!
- * \brief Extension of MultiLevelTiling for auto-tensorizing with a single intrinsic.
- */
-class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode {
- protected:
-  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final {
-    auto desc_func = tir::TensorIntrin::Get(intrin_name).value()->desc;
-    if (!CheckAutoTensorizeApplicable(sch, block_rv, desc_func)) {
-      TVM_PY_LOG(INFO, logger) << "The workload cannot be tensorized.";
-      return {sch};
-    }
-
-    auto res = MultiLevelTilingNode::Apply(sch->Copy(), block_rv);
-
-    if (res.empty()) {
-      TVM_PY_LOG(INFO, logger) << "The workload cannot be tensorized.";
-      return {sch};
-    }
-    TVM_PY_LOG(INFO, logger) << "Tensorizing with " << intrin_name;
-    return res;
+Array<tir::Schedule> MultiLevelTilingWithIntrinNode::Apply(const tir::Schedule& sch,
+                                                           const tir::BlockRV& block_rv) {
+  auto desc_func = tir::TensorIntrin::Get(intrin_name).value()->desc;
+  if (!CheckAutoTensorizeApplicable(sch, block_rv, desc_func)) {
+    TVM_PY_LOG(INFO, logger) << "The workload cannot be tensorized.";
+    return {sch};
   }
 
-  // Inherited from ScheduleRuleNode
-  ScheduleRule Clone() const final {
-    ObjectPtr<MultiLevelTilingWithIntrinNode> n =
-        make_object<MultiLevelTilingWithIntrinNode>(*this);
-    return ScheduleRule(n);
-  }
+  auto res = MultiLevelTilingNode::Apply(sch->Copy(), block_rv);
 
-  // Override ApplySubRules to tile the inner loops according to the given tensor intrinsic, then
-  // tile the outerloops.
-  virtual std::vector<State> ApplySubRules(std::vector<State> states) {
-    states = SubRule(std::move(states), [&](State state) {
-      if (auto block_rv = TileForIntrin(state->sch, state->block_rv, intrin_name)) {
-        state->block_rv = block_rv.value();
-        return std::vector<State>(1, state);
-      }
-      return std::vector<State>();
-    });
-    return MultiLevelTilingNode::ApplySubRules(states);
+  if (res.empty()) {
+    TVM_PY_LOG(INFO, logger) << "The workload cannot be tensorized.";
+    return {sch};
   }
+  TVM_PY_LOG(INFO, logger) << "Tensorizing with " << intrin_name;
+  return res;
+}
 
- public:
-  /*! \brief The name of a tensor intrinsic. */
-  String intrin_name;
+std::vector<State> MultiLevelTilingWithIntrinNode::ApplySubRules(std::vector<State> states) {
+  states = SubRule(std::move(states), [&](State state) {
+    if (auto block_rv = TileForIntrin(state->sch, state->block_rv, intrin_name)) {
+      state->block_rv = block_rv.value();
+      return std::vector<State>(1, state);
+    }
+    return std::vector<State>();
+  });
 
-  static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingWithIntrin";
-  TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingWithIntrinNode, MultiLevelTilingNode);
-};
+  return MultiLevelTilingNode::ApplySubRules(states);
+}
 
 ScheduleRule ScheduleRule::MultiLevelTilingWithIntrin(
     String intrin_name, String structure, Optional<Array<String>> tile_binds,
@@ -106,6 +90,5 @@ ScheduleRule ScheduleRule::MultiLevelTilingWithIntrin(
 TVM_REGISTER_NODE_TYPE(MultiLevelTilingWithIntrinNode);
 TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingWithIntrin")
     .set_body_typed(ScheduleRule::MultiLevelTilingWithIntrin);
-
 }  // namespace meta_schedule
 }  // namespace tvm
diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.h b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.h
new file mode 100644
index 000000000000..dccf99dc1fc4
--- /dev/null
+++ b/src/meta_schedule/schedule_rule/multi_level_tiling_with_intrin.h
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#ifndef TVM_META_SCHEDULE_SCHEDULE_RULE_MULTI_LEVEL_TILING_WITH_INTRIN_H_
+#define TVM_META_SCHEDULE_SCHEDULE_RULE_MULTI_LEVEL_TILING_WITH_INTRIN_H_
+
+#include
+#include
+
+#include "../../tir/schedule/analysis.h"
+#include "../utils.h"
+#include "multi_level_tiling.h"
+
+namespace tvm {
+namespace meta_schedule {
+
+/*!
+ * \brief Tile a subset of loops in the block according to the given tensor intrinsic, and annotate
+ * the tiled block for tensorization by postproc rewrite.
+ */
+Optional<tir::BlockRV> TileForIntrin(tir::Schedule sch, tir::BlockRV block,
+                                     const std::string& intrin_name);
+
+/*!
+ * \brief Extension of MultiLevelTiling for auto-tensorizing with a single intrinsic.
+ */
+class MultiLevelTilingWithIntrinNode : public MultiLevelTilingNode {
+ protected:
+  Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) override;
+
+  // Inherited from ScheduleRuleNode
+  ScheduleRule Clone() const override {
+    ObjectPtr<MultiLevelTilingWithIntrinNode> n =
+        make_object<MultiLevelTilingWithIntrinNode>(*this);
+    return ScheduleRule(n);
+  }
+
+  // Override ApplySubRules to tile the inner loops according to the given tensor intrinsic, then
+  // tile the outer loops.
+  std::vector<State> ApplySubRules(std::vector<State> states) override;
+
+ public:
+  /*! \brief The name of a tensor intrinsic. */
+  String intrin_name;
+
+  static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingWithIntrin";
+  TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingWithIntrinNode, MultiLevelTilingNode);
+};
+
+}  // namespace meta_schedule
+}  // namespace tvm
+
+#endif  // TVM_META_SCHEDULE_SCHEDULE_RULE_MULTI_LEVEL_TILING_WITH_INTRIN_H_
diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc
index 8da7cfdd5b97..ca589957294b 100644
--- a/src/tir/ir/data_type_rewriter.cc
+++ b/src/tir/ir/data_type_rewriter.cc
@@ -244,6 +244,23 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const AllocateNode* op) {
   }
 }
 
+Stmt IndexDataTypeRewriter::VisitStmt_(const AllocateConstNode* op) {
+  bool is_enabled = is_enabled_;
+  is_enabled_ = true;
+  auto new_extents = op->extents.Map([this](const PrimExpr& e) { return this->VisitExpr(e); });
+  is_enabled_ = is_enabled;
+  auto new_body = this->VisitStmt(op->body);
+  if (!new_extents.same_as(op->extents) || !new_body.same_as(op->body)) {
+    AllocateConst new_allocate = GetRef<AllocateConst>(op);
+    auto* n = new_allocate.CopyOnWrite();
+    n->extents = std::move(new_extents);
+    n->body = std::move(new_body);
+    return std::move(new_allocate);
+  } else {
+    return GetRef<Stmt>(op);
+  }
+}
+
 Stmt IndexDataTypeRewriter::VisitStmt_(const DeclBufferNode* op) {
   Buffer new_buffer = VisitBuffer(op->buffer);
   DeclBuffer decl_buffer = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
@@ -379,6 +396,10 @@ IterVar IndexDataTypeRewriter::VisitIterVar(const IterVar& iter_var) {
 }
 
 Buffer IndexDataTypeRewriter::VisitBuffer(const Buffer& buffer) {
+  if (auto it = buffer_remap_.find(buffer); it != buffer_remap_.end()) {
+    return (*it).second;
+  }
+
   bool is_enabled = is_enabled_;
   is_enabled_ = true;
 
@@ -404,7 +425,7 @@ Buffer IndexDataTypeRewriter::VisitBuffer(const Buffer& buffer) {
 }
 
 BufferRegion IndexDataTypeRewriter::VisitBufferRegion(const BufferRegion& buffer_region) {
-  Buffer remapped_buffer = GetRemappedBuffer(buffer_region->buffer);
+  Buffer remapped_buffer = VisitBuffer(buffer_region->buffer);
 
   bool is_enabled = is_enabled_;
   is_enabled_ = true;
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 1652786cb510..59f98c1ad1a8 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -672,6 +672,10 @@ BlockRealize::BlockRealize(Array<PrimExpr> values, PrimExpr predicate, Block block,
       << "ValueError: BlockRealize needs to have the same number of iter_vars and binding values";
   CHECK(predicate.dtype().is_bool()) << "TypeError: Expect Block.predicate to be a bool expression";
   ObjectPtr<BlockRealizeNode> node = make_object<BlockRealizeNode>();
+  // The dtype of each iter var must match the dtype of the value bound to it.
+  for (size_t i = 0; i < values.size(); ++i) {
+    ICHECK(block->iter_vars[i]->var.dtype() == values[i].dtype());
+  }
+
   node->iter_values = std::move(values);
   node->predicate = std::move(predicate);
   node->block = std::move(block);
diff --git a/tests/python/contrib/test_hexagon/test_conv2d_async.py b/tests/python/contrib/test_hexagon/test_conv2d_async.py
new file mode 100644
index 000000000000..5e92a49f8613
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/test_conv2d_async.py
@@ -0,0 +1,184 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+# pylint: disable=missing-docstring
+"""Test asynchronous (DMA-pipelined) conv2d schedules on Hexagon."""
+import tempfile
+
+import numpy as np
+import pytest
+import tvm.testing
+import tvm.topi.testing
+from tvm import meta_schedule as ms
+from tvm import relay
+from tvm.contrib.hexagon.meta_schedule import (
+    get_hexagon_local_builder,
+    get_hexagon_rpc_runner,
+)
+from tvm.meta_schedule import postproc, schedule_rule
+from tvm.tir.tensor_intrin.hexagon import (
+    VRMPY_u8u8i32_VTCM_READS_INTRIN,
+)
+
+from .infrastructure import get_hexagon_target
+
+
+def tune_vrmpy_auto_tensorize(mod, params, hexagon_launcher):
+    sch_rules_async = [
+        schedule_rule.ApplyCustomRule(),
+        schedule_rule.AutoInline(
+            into_producer=False,
+            into_consumer=True,
+            inline_const_tensor=True,
+            disallow_if_then_else=True,
+            require_injective=True,
+            require_ordered=True,
+            disallow_op=["tir.exp"],
+        ),
+        schedule_rule.MultiLevelTilingHexagon(
+            intrin_groups=[
+                {"compute": VRMPY_u8u8i32_VTCM_READS_INTRIN},
+            ],
+            structure="SRSRS",
+            tile_binds=None,
+            max_innermost_factor=64,  # 64 // tensor intrin size
+            vector_load_lens=None,
+            reuse_read=ms.schedule_rule.ReuseType(
+                req="must",
+                levels=[2],
+                scope="global.vtcm",
+            ),
+            reuse_write=None,
+            use_software_pipeline=True,
+        ),
+        schedule_rule.ParallelizeVectorizeUnroll(
+            max_jobs_per_core=-1,
+            max_vectorize_extent=-1,
+            unroll_max_steps=[8, 16, 32],
+            unroll_explicit=True,
+        ),
+    ]
+
+    postprocs = [
+        postproc.RewriteParallelVectorizeUnroll(),
+        postproc.RewriteReductionBlock(),
+        postproc.RewriteTensorize(vectorize_init_loop=True),
+        postproc.VerifyVTCMLimit(),
+        postproc.DisallowAsyncStridedMemCopy(merge_async_commit_queue_scope=False),
+    ]
+
+    target = get_hexagon_target("v68")
+    executor = relay.backend.Executor("graph", {"link-params": True})
+    mod = mod.with_attr("executor", executor)
+
+    use_async = True
+
+    if use_async:
+        config = {
+            "tir.use_async_copy": True,
+            "tir.merge_async_commit_queue_scope": False,
+        }
+
+        ctx = tvm.transform.PassContext(
+            opt_level=3,
+            config=config,
+        )
+        sch_rules = sch_rules_async
+
+    with tempfile.TemporaryDirectory() as work_dir:
+        database = ms.relay_integration.tune_relay(
+            mod=mod,
+            target=target,
+            params=params,
+            work_dir=work_dir,
+            max_trials_global=20000,
+            max_trials_per_task=16,
+            num_trials_per_iter=16,
+            strategy="replay-trace",
+            builder=get_hexagon_local_builder(ctx),
+            runner=get_hexagon_rpc_runner(hexagon_launcher, number=20),
+            space=ms.space_generator.PostOrderApply(
+                sch_rules=sch_rules,
+                postprocs=postprocs,
+                mutator_probs={},
+            ),
+        )
+
+        config.update(
+            {
+                "relay.backend.use_meta_schedule": True,
+                "relay.backend.tir_converter": "default",
+            }
+        )
+
+        return ms.relay_integration.compile_relay(
+            database=database, mod=mod, target=target, params=params, pass_config=config
+        )
+
+
+@tvm.testing.requires_hexagon
+def test_conv2d_relay_auto_schedule(hexagon_launcher):
+    """Test conv2d using auto schedule."""
+    if hexagon_launcher.is_simulator():
+        pytest.skip(msg="Tuning on simulator not supported.")
+
+    if tvm.testing.utils.IS_IN_CI:
+        pytest.skip("Skipping test since it takes too long in CI.")
+
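+    # Workload: a 1x64x56x56 uint8 NCHW conv2d with a 3x3 kernel and
+    # out_dtype="int32", which can auto-tensorize to vrmpy.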
i_size, o_size, h_size, w_size = 64, 64, 56, 56 + k_height_size = k_width_size = 3 + + strides = (1, 1) + padding = (0, 0) + + d_shape = (1, i_size, h_size, w_size) + w_shape = (o_size, i_size, k_height_size, k_width_size) + bias_shape = (1, o_size, 1, 1) + + data = relay.var("data", shape=d_shape, dtype="uint8") + weight = relay.var("weight", shape=w_shape, dtype="uint8") + conv2d = relay.nn.conv2d( + data=data, + weight=weight, + kernel_size=(k_height_size, k_width_size), + channels=o_size, + padding=padding, + strides=strides, + out_dtype="int32", + ) + mod = tvm.IRModule.from_expr(conv2d) + + data_np = np.random.uniform(1, 10, size=d_shape).astype("uint8") + weight_np = np.random.uniform(1, 10, size=w_shape).astype("uint8") + bias_np = np.random.uniform(1, 10, size=bias_shape).astype("int32") + params = {"weight": weight_np, "bias": bias_np} + + ref = ( + relay.create_executor("graph", mod=mod, device=tvm.cpu(0), target="llvm") + .evaluate()(*[data_np, weight_np]) + .numpy() + ) + + lib = tune_vrmpy_auto_tensorize(mod, params, hexagon_launcher) + + with hexagon_launcher.create_session() as session: + rt_mod = session.get_executor_from_factory(lib) + rt_mod.set_input("data", data_np) + rt_mod.run() + + out = rt_mod.get_output(0).numpy() + np.testing.assert_allclose(ref, out, atol=1e-4, rtol=1e-5) diff --git a/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_hexagon.py b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_hexagon.py new file mode 100644 index 000000000000..50eab8ddb944 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_schedule_rule_mlt_hexagon.py @@ -0,0 +1,1177 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
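+#
+# The tests below build design spaces with MultiLevelTilingHexagon via
+# generate_design_space and compare the resulting sketches against
+# hand-written expected TVMScript modules using check_sketches.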
+# pylint: disable=missing-module-docstring,missing-function-docstring,missing-class-docstring +from tests.python.contrib.test_hexagon.test_meta_schedule import dense_compute +import tvm +from tvm.meta_schedule import schedule_rule +import tvm.testing +from tvm import meta_schedule as ms +from tvm import te +from tvm.meta_schedule.testing import te_workload +from tvm.meta_schedule.testing.space_generation import ( + check_sketches, + generate_design_space, + get_rules, +) +from tvm.script import tir as T +from tvm.tir.tensor_intrin.cuda import get_wmma_intrin_group +from tvm.tir.tensor_intrin.hexagon import VRMPY_u8u8i32_INTRIN, VRMPY_u8u8i32_VTCM_INTRIN + + +def multi_level_tiling_hexagon( + *, + write_reuse_scope="global.vtcm", + in_dtype="uint8", + out_dtype="int32", + use_software_pipeline=False, +) -> ms.schedule_rule.ScheduleRule: + assert write_reuse_scope in ["global", "global.vtcm"] + if not isinstance(in_dtype, list): + in_dtype = [in_dtype] + if not isinstance(out_dtype, list): + out_dtype = [out_dtype] + return ms.schedule_rule.MultiLevelTilingHexagon( + intrin_groups=[ + {"compute": VRMPY_u8u8i32_VTCM_INTRIN}, + ], + structure="SRSRS", + tile_binds=None, + max_innermost_factor=64, # 64 // tensor intrin size + vector_load_lens=None, + reuse_read=ms.schedule_rule.ReuseType( + req="must", + levels=[2], + scope="global.vtcm", + ), + reuse_write=ms.schedule_rule.ReuseType( + req="must" if write_reuse_scope == "shared" else "no", + levels=[1], + scope=write_reuse_scope, + ), + use_software_pipeline=use_software_pipeline, + ) + + +def test_dense_base(): + @T.prim_func + def main( + X: T.Buffer[(128, 768), "uint8"], + packed_width: T.Buffer[(24, 192, 32, 4), "uint8"], + compute: T.Buffer[(128, 768), "int32"], + ): + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + X_global_vtcm = T.alloc_buffer([128, 768], dtype="uint8", scope="global.vtcm") + packed_width_global_vtcm = T.alloc_buffer( + [24, 192, 32, 4], dtype="uint8", scope="global.vtcm" + ) + for i_0, j_0_0, k_0_0 in T.grid(128, 6, 48): + for ax0_ax1_fused in T.serial(16): + with T.block("X_global.vtcm"): + v0 = T.axis.spatial(128, i_0) + v1 = T.axis.spatial(768, k_0_0 * 16 + ax0_ax1_fused) + T.reads(X[v0, v1]) + T.writes(X_global_vtcm[v0, v1]) + X_global_vtcm[v0, v1] = X[v0, v1] + for ax0_ax1_ax2_ax3_fused in T.serial(2048): + with T.block("packed_width_global.vtcm"): + v0 = T.axis.spatial(24, j_0_0 * 4 + ax0_ax1_ax2_ax3_fused // 512) + v1 = T.axis.spatial(192, k_0_0 * 4 + ax0_ax1_ax2_ax3_fused % 512 // 128) + v2 = T.axis.spatial(32, ax0_ax1_ax2_ax3_fused % 128 // 4) + v3 = T.axis.spatial(4, ax0_ax1_ax2_ax3_fused % 4) + T.reads(packed_width[v0, v1, v2, v3]) + T.writes(packed_width_global_vtcm[v0, v1, v2, v3]) + packed_width_global_vtcm[v0, v1, v2, v3] = packed_width[v0, v1, v2, v3] + for i_1, j_0_1, k_0_1, i_2, j_0_2 in T.grid(1, 2, 4, 1, 2): + with T.block("compute_o"): + v_i = T.axis.spatial(128, i_0 + i_1 + i_2) + v_j_o = T.axis.spatial(24, j_0_0 * 4 + j_0_1 * 2 + j_0_2) + v_k_o = T.axis.reduce(192, k_0_0 * 4 + k_0_1) + T.reads( + X_global_vtcm[v_i, v_k_o * 4 : v_k_o * 4 + 4], + packed_width_global_vtcm[v_j_o, v_k_o, 0:32, 0:4], + ) + T.writes(compute[v_i, v_j_o * 32 : v_j_o * 32 + 32]) + T.block_attr({"meta_schedule.auto_tensorize": "dot_32x4_u8u8i32_vtcm_vrmpy"}) + with T.init(): + for j_1 in T.serial(32): + with T.block("compute_init"): + v_j_i_init = T.axis.spatial(32, j_1) + T.reads() + T.writes(compute[v_i, v_j_o * 32 + v_j_i_init]) + compute[v_i, v_j_o 
* 32 + v_j_i_init] = 0 + for j_1, k_1 in T.grid(32, 4): + with T.block("compute"): + v_j_i, v_k_i = T.axis.remap("SR", [j_1, k_1]) + T.reads( + compute[v_i, v_j_o * 32 + v_j_i], + X_global_vtcm[v_i, v_k_o * 4 + v_k_i], + packed_width_global_vtcm[v_j_o, v_k_o, v_j_i, v_k_i], + ) + T.writes(compute[v_i, v_j_o * 32 + v_j_i]) + T.block_attr({"meta_schedule.tiling_structure": "SRSRS"}) + compute[v_i, v_j_o * 32 + v_j_i] = compute[ + v_i, v_j_o * 32 + v_j_i + ] + T.Cast("int32", X_global_vtcm[v_i, v_k_o * 4 + v_k_i]) * T.Cast( + "int32", packed_width_global_vtcm[v_j_o, v_k_o, v_j_i, v_k_i] + ) + + # fmt: on + decision_0 = [ + ("SamplePerfectTile", [4, 1, 1]), + ("SamplePerfectTile", [2, 2, 2]), + ("SamplePerfectTile", [1, 4]), + ] + + mod = te.create_prim_func( + dense_compute( + m=128, + n=768, + k=768, + ) + ) + + actual_design_space = generate_design_space( + kind="hexagon", + mod=mod, + target=tvm.target.Target("hexagon"), + types=None, + sch_rules=[ + multi_level_tiling_hexagon(), + ] + + get_rules(kind="hexagon", types=ms.schedule_rule.AutoInline), + ) + check_sketches( + mod, + sketches=actual_design_space, + expected_mods=[main], + expected_decisions=[decision_0], + ) + + +def test_dense_with_fallback(): + + # from tvm.script import tir as T + @T.prim_func + def main( + X: T.Buffer[(128, 768), "uint8"], + packed_width: T.Buffer[(24, 192, 32, 4), "uint8"], + compute: T.Buffer[(128, 768), "int32"], + ): + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + X_global_vtcm = T.alloc_buffer([128, 768], dtype="uint8", scope="global.vtcm") + packed_width_global_vtcm = T.alloc_buffer( + [24, 192, 32, 4], dtype="uint8", scope="global.vtcm" + ) + for i_0, j_0_0, k_0_0 in T.grid(128, 6, 192): + for ax0_ax1_fused in T.serial(4): + with T.block("X_global.vtcm"): + v0 = T.axis.spatial(128, i_0) + v1 = T.axis.spatial(768, k_0_0 * 4 + ax0_ax1_fused) + T.reads(X[v0, v1]) + T.writes(X_global_vtcm[v0, v1]) + X_global_vtcm[v0, v1] = X[v0, v1] + for ax0_ax1_ax2_ax3_fused in T.serial(512): + with T.block("packed_width_global.vtcm"): + v0 = T.axis.spatial(24, j_0_0 * 4 + ax0_ax1_ax2_ax3_fused // 128) + v1 = T.axis.spatial(192, k_0_0) + v2 = T.axis.spatial(32, ax0_ax1_ax2_ax3_fused % 128 // 4) + v3 = T.axis.spatial(4, ax0_ax1_ax2_ax3_fused % 4) + T.reads(packed_width[v0, v1, v2, v3]) + T.writes(packed_width_global_vtcm[v0, v1, v2, v3]) + packed_width_global_vtcm[v0, v1, v2, v3] = packed_width[v0, v1, v2, v3] + for i_1, j_0_1, k_0_1, i_2, j_0_2 in T.grid(1, 2, 1, 1, 2): + with T.block("compute_o"): + v_i = T.axis.spatial(128, i_0 + i_1 + i_2) + v_j_o = T.axis.spatial(24, j_0_0 * 4 + j_0_1 * 2 + j_0_2) + v_k_o = T.axis.reduce(192, k_0_1 + k_0_0) + T.reads( + X_global_vtcm[v_i, v_k_o * 4 : v_k_o * 4 + 4], + packed_width_global_vtcm[v_j_o, v_k_o, 0:32, 0:4], + ) + T.writes(compute[v_i, v_j_o * 32 : v_j_o * 32 + 32]) + T.block_attr({"meta_schedule.auto_tensorize": "dot_32x4_u8u8i32_vtcm_vrmpy"}) + with T.init(): + for j_1 in T.serial(32): + with T.block("compute_init"): + v_j_i_init = T.axis.spatial(32, j_1) + T.reads() + T.writes(compute[v_i, v_j_o * 32 + v_j_i_init]) + compute[v_i, v_j_o * 32 + v_j_i_init] = 0 + for j_1, k_1 in T.grid(32, 4): + with T.block("compute"): + v_j_i, v_k_i = T.axis.remap("SR", [j_1, k_1]) + T.reads( + compute[v_i, v_j_o * 32 + v_j_i], + X_global_vtcm[v_i, v_k_o * 4 + v_k_i], + packed_width_global_vtcm[v_j_o, v_k_o, v_j_i, v_k_i], + ) + T.writes(compute[v_i, v_j_o * 32 + v_j_i]) + 
T.block_attr({"meta_schedule.tiling_structure": "SRSRS"}) + compute[v_i, v_j_o * 32 + v_j_i] = compute[ + v_i, v_j_o * 32 + v_j_i + ] + T.Cast("int32", X_global_vtcm[v_i, v_k_o * 4 + v_k_i]) * T.Cast( + "int32", packed_width_global_vtcm[v_j_o, v_k_o, v_j_i, v_k_i] + ) + + # fmt: on + decision_0 = [ + ("SamplePerfectTile", [4, 1, 1]), + ("SamplePerfectTile", [2, 2, 2]), + ("SamplePerfectTile", [2, 1]), + ] + + mod = te.create_prim_func( + dense_compute( + m=128, + n=768, + k=768, + ) + ) + + actual_design_space = generate_design_space( + kind="hexagon", + mod=mod, + target=tvm.target.Target("hexagon"), + types=None, + sch_rules=[ + multi_level_tiling_hexagon(), + ] + + get_rules(kind="hexagon", types=ms.schedule_rule.AutoInline), + ) + + check_sketches( + mod, + sketches=actual_design_space, + expected_mods=[main], + expected_decisions=[decision_0], + ) + + +def test_dense_with_pipeline(): + @T.prim_func + def main( + X: T.Buffer[(128, 768), "uint8"], + packed_width: T.Buffer[(24, 192, 32, 4), "uint8"], + compute: T.Buffer[(128, 768), "int32"], + ): + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + X_global_vtcm = T.alloc_buffer([128, 768], dtype="uint8", scope="global.vtcm") + packed_width_global_vtcm = T.alloc_buffer( + [24, 192, 32, 4], dtype="uint8", scope="global.vtcm" + ) + for i_0, j_0_0 in T.grid(128, 6): + for k_0_0_fused in T.serial( + 48, + annotations={ + "software_pipeline_async_stages": [0], + "software_pipeline_order": [0, 1, 2], + "software_pipeline_stage": [0, 0, 1], + }, + ): + for ax0_ax1_fused in T.serial(16): + with T.block("X_global.vtcm"): + v0 = T.axis.spatial(128, i_0) + v1 = T.axis.spatial(768, k_0_0_fused * 16 + ax0_ax1_fused) + T.reads(X[v0, v1]) + T.writes(X_global_vtcm[v0, v1]) + X_global_vtcm[v0, v1] = X[v0, v1] + for ax0_ax1_ax2_ax3_fused in T.serial(2048): + with T.block("packed_width_global.vtcm"): + v0 = T.axis.spatial(24, j_0_0 * 4 + ax0_ax1_ax2_ax3_fused // 512) + v1 = T.axis.spatial( + 192, k_0_0_fused * 4 + ax0_ax1_ax2_ax3_fused % 512 // 128 + ) + v2 = T.axis.spatial(32, ax0_ax1_ax2_ax3_fused % 128 // 4) + v3 = T.axis.spatial(4, ax0_ax1_ax2_ax3_fused % 4) + T.reads(packed_width[v0, v1, v2, v3]) + T.writes(packed_width_global_vtcm[v0, v1, v2, v3]) + packed_width_global_vtcm[v0, v1, v2, v3] = packed_width[v0, v1, v2, v3] + for i_1, j_0_1, k_0_1, i_2, j_0_2 in T.grid(1, 2, 4, 1, 2): + with T.block("compute_o"): + v_i = T.axis.spatial(128, i_1 + i_2 + i_0) + v_j_o = T.axis.spatial(24, j_0_0 * 4 + j_0_1 * 2 + j_0_2) + v_k_o = T.axis.reduce(192, k_0_0_fused * 4 + k_0_1) + T.reads( + X_global_vtcm[v_i, v_k_o * 4 : v_k_o * 4 + 4], + packed_width_global_vtcm[v_j_o, v_k_o, 0:32, 0:4], + ) + T.writes(compute[v_i, v_j_o * 32 : v_j_o * 32 + 32]) + T.block_attr( + {"meta_schedule.auto_tensorize": "dot_32x4_u8u8i32_vtcm_vrmpy"} + ) + with T.init(): + for j_1 in T.serial(32): + with T.block("compute_init"): + v_j_i_init = T.axis.spatial(32, j_1) + T.reads() + T.writes(compute[v_i, v_j_o * 32 + v_j_i_init]) + compute[v_i, v_j_o * 32 + v_j_i_init] = 0 + for j_1, k_1 in T.grid(32, 4): + with T.block("compute"): + v_j_i, v_k_i = T.axis.remap("SR", [j_1, k_1]) + T.reads( + compute[v_i, v_j_o * 32 + v_j_i], + X_global_vtcm[v_i, v_k_o * 4 + v_k_i], + packed_width_global_vtcm[v_j_o, v_k_o, v_j_i, v_k_i], + ) + T.writes(compute[v_i, v_j_o * 32 + v_j_i]) + T.block_attr({"meta_schedule.tiling_structure": "SRSRS"}) + compute[v_i, v_j_o * 32 + v_j_i] = compute[ + v_i, v_j_o * 32 + v_j_i + ] + T.Cast("int32", 
X_global_vtcm[v_i, v_k_o * 4 + v_k_i]) * T.Cast( + "int32", packed_width_global_vtcm[v_j_o, v_k_o, v_j_i, v_k_i] + ) + + # fmt: on + decision_0 = [ + ("SamplePerfectTile", [4, 1, 1]), + ("SamplePerfectTile", [2, 2, 2]), + ("SamplePerfectTile", [1, 4]), + ] + + mod = te.create_prim_func( + dense_compute( + m=128, + n=768, + k=768, + ) + ) + + actual_design_space = generate_design_space( + kind="hexagon", + mod=mod, + target=tvm.target.Target("hexagon"), + types=None, + sch_rules=[ + multi_level_tiling_hexagon(use_software_pipeline=True), + ] + + get_rules(kind="hexagon", types=ms.schedule_rule.AutoInline), + ) + + check_sketches( + mod, + sketches=actual_design_space, + expected_mods=[main], + expected_decisions=[decision_0], + ) + + +def test_dense_global(): + + # from tvm.script import tir as T + @T.prim_func + def main( + X: T.Buffer[(128, 768), "uint8"], + packed_width: T.Buffer[(24, 192, 32, 4), "uint8"], + compute: T.Buffer[(128, 768), "int32"], + ): + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + X_global_vtcm = T.alloc_buffer([128, 768], dtype="uint8", scope="global.vtcm") + packed_width_global_vtcm = T.alloc_buffer( + [24, 192, 32, 4], dtype="uint8", scope="global.vtcm" + ) + for i_0, j_0_0, k_0_0 in T.grid(128, 6, 192): + for ax0_ax1_fused in T.serial(4): + with T.block("X_global.vtcm"): + v0 = T.axis.spatial(128, i_0) + v1 = T.axis.spatial(768, k_0_0 * 4 + ax0_ax1_fused) + T.reads(X[v0, v1]) + T.writes(X_global_vtcm[v0, v1]) + X_global_vtcm[v0, v1] = X[v0, v1] + for ax0_ax1_ax2_ax3_fused in T.serial(512): + with T.block("packed_width_global.vtcm"): + v0 = T.axis.spatial(24, j_0_0 * 4 + ax0_ax1_ax2_ax3_fused // 128) + v1 = T.axis.spatial(192, k_0_0) + v2 = T.axis.spatial(32, ax0_ax1_ax2_ax3_fused % 128 // 4) + v3 = T.axis.spatial(4, ax0_ax1_ax2_ax3_fused % 4) + T.reads(packed_width[v0, v1, v2, v3]) + T.writes(packed_width_global_vtcm[v0, v1, v2, v3]) + packed_width_global_vtcm[v0, v1, v2, v3] = packed_width[v0, v1, v2, v3] + for i_1, j_0_1, k_0_1, i_2, j_0_2 in T.grid(1, 2, 1, 1, 2): + with T.block("compute_o"): + v_i = T.axis.spatial(128, i_0 + i_1 + i_2) + v_j_o = T.axis.spatial(24, j_0_0 * 4 + j_0_1 * 2 + j_0_2) + v_k_o = T.axis.reduce(192, k_0_1 + k_0_0) + T.reads( + X_global_vtcm[v_i, v_k_o * 4 : v_k_o * 4 + 4], + packed_width_global_vtcm[v_j_o, v_k_o, 0:32, 0:4], + ) + T.writes(compute[v_i, v_j_o * 32 : v_j_o * 32 + 32]) + T.block_attr({"meta_schedule.auto_tensorize": "dot_32x4_u8u8i32_vtcm_vrmpy"}) + with T.init(): + for j_1 in T.serial(32): + with T.block("compute_init"): + v_j_i_init = T.axis.spatial(32, j_1) + T.reads() + T.writes(compute[v_i, v_j_o * 32 + v_j_i_init]) + compute[v_i, v_j_o * 32 + v_j_i_init] = 0 + for j_1, k_1 in T.grid(32, 4): + with T.block("compute"): + v_j_i, v_k_i = T.axis.remap("SR", [j_1, k_1]) + T.reads( + compute[v_i, v_j_o * 32 + v_j_i], + X_global_vtcm[v_i, v_k_o * 4 + v_k_i], + packed_width_global_vtcm[v_j_o, v_k_o, v_j_i, v_k_i], + ) + T.writes(compute[v_i, v_j_o * 32 + v_j_i]) + T.block_attr({"meta_schedule.tiling_structure": "SRSRS"}) + compute[v_i, v_j_o * 32 + v_j_i] = compute[ + v_i, v_j_o * 32 + v_j_i + ] + T.Cast("int32", X_global_vtcm[v_i, v_k_o * 4 + v_k_i]) * T.Cast( + "int32", packed_width_global_vtcm[v_j_o, v_k_o, v_j_i, v_k_i] + ) + + # fmt: on + decision_0 = [ + ("SamplePerfectTile", [4, 1, 1]), + ("SamplePerfectTile", [2, 2, 2]), + ("SamplePerfectTile", [2, 1]), + ] + + mod = te.create_prim_func( + dense_compute( + m=128, + n=768, + k=768, + ) + ) + + 
actual_design_space = generate_design_space( + kind="hexagon", + mod=mod, + target=tvm.target.Target("hexagon"), + types=None, + sch_rules=[ + multi_level_tiling_hexagon(write_reuse_scope="global"), + ] + + get_rules(kind="hexagon", types=ms.schedule_rule.AutoInline), + ) + check_sketches( + mod, + sketches=actual_design_space, + expected_mods=[main], + expected_decisions=[decision_0], + ) + + +def test_padded_dense(): + @T.prim_func + def main( + X: T.Buffer[(128, 768), "uint8"], + packed_width: T.Buffer[(24, 192, 32, 4), "uint8"], + compute: T.Buffer[(128, 768), "int32"], + ): + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + X_global_vtcm = T.alloc_buffer([128, 768], dtype="uint8", scope="global.vtcm") + packed_width_global_vtcm = T.alloc_buffer( + [24, 192, 32, 4], dtype="uint8", scope="global.vtcm" + ) + for i_0, j_0_0, k_0_0 in T.grid(128, 6, 48): + for ax0_ax1_fused in T.serial(16): + with T.block("X_global.vtcm"): + v0 = T.axis.spatial(128, i_0) + v1 = T.axis.spatial(768, k_0_0 * 16 + ax0_ax1_fused) + T.reads(X[v0, v1]) + T.writes(X_global_vtcm[v0, v1]) + X_global_vtcm[v0, v1] = X[v0, v1] + for ax0_ax1_ax2_ax3_fused in T.serial(2048): + with T.block("packed_width_global.vtcm"): + v0 = T.axis.spatial(24, j_0_0 * 4 + ax0_ax1_ax2_ax3_fused // 512) + v1 = T.axis.spatial(192, k_0_0 * 4 + ax0_ax1_ax2_ax3_fused % 512 // 128) + v2 = T.axis.spatial(32, ax0_ax1_ax2_ax3_fused % 128 // 4) + v3 = T.axis.spatial(4, ax0_ax1_ax2_ax3_fused % 4) + T.reads(packed_width[v0, v1, v2, v3]) + T.writes(packed_width_global_vtcm[v0, v1, v2, v3]) + packed_width_global_vtcm[v0, v1, v2, v3] = packed_width[v0, v1, v2, v3] + for i_1, j_0_1, k_0_1, i_2, j_0_2 in T.grid(1, 2, 4, 1, 2): + with T.block("compute_o"): + v_i = T.axis.spatial(128, i_0 + i_1 + i_2) + v_j_o = T.axis.spatial(24, j_0_0 * 4 + j_0_1 * 2 + j_0_2) + v_k_o = T.axis.reduce(192, k_0_0 * 4 + k_0_1) + T.reads( + X_global_vtcm[v_i, v_k_o * 4 : v_k_o * 4 + 4], + packed_width_global_vtcm[v_j_o, v_k_o, 0:32, 0:4], + ) + T.writes(compute[v_i, v_j_o * 32 : v_j_o * 32 + 32]) + T.block_attr({"meta_schedule.auto_tensorize": "dot_32x4_u8u8i32_vtcm_vrmpy"}) + with T.init(): + for j_1 in T.serial(32): + with T.block("compute_init"): + v_j_i_init = T.axis.spatial(32, j_1) + T.reads() + T.writes(compute[v_i, v_j_o * 32 + v_j_i_init]) + compute[v_i, v_j_o * 32 + v_j_i_init] = 0 + for j_1, k_1 in T.grid(32, 4): + with T.block("compute"): + v_j_i, v_k_i = T.axis.remap("SR", [j_1, k_1]) + T.reads( + compute[v_i, v_j_o * 32 + v_j_i], + X_global_vtcm[v_i, v_k_o * 4 + v_k_i], + packed_width_global_vtcm[v_j_o, v_k_o, v_j_i, v_k_i], + ) + T.writes(compute[v_i, v_j_o * 32 + v_j_i]) + T.block_attr({"meta_schedule.tiling_structure": "SRSRS"}) + compute[v_i, v_j_o * 32 + v_j_i] = compute[ + v_i, v_j_o * 32 + v_j_i + ] + T.Cast("int32", X_global_vtcm[v_i, v_k_o * 4 + v_k_i]) * T.Cast( + "int32", packed_width_global_vtcm[v_j_o, v_k_o, v_j_i, v_k_i] + ) + + # fmt: on + decision_0 = [ + ("SamplePerfectTile", [4, 1, 1]), + ("SamplePerfectTile", [2, 2, 2]), + ("SamplePerfectTile", [1, 4]), + ] + + mod = te.create_prim_func( + dense_compute( + m=128, + n=768, + k=768, + ) + ) + + actual_design_space = generate_design_space( + kind="hexagon", + mod=mod, + target=tvm.target.Target("hexagon"), + types=None, + sch_rules=[ + multi_level_tiling_hexagon(write_reuse_scope="global"), + ] + + get_rules(kind="hexagon", types=ms.schedule_rule.AutoInline), + ) + + check_sketches( + mod, + sketches=actual_design_space, + 
expected_mods=[main], + expected_decisions=[decision_0], + ) + + +def test_conv2d(): + + # from tvm.script import tir as T + @T.prim_func + def main( + inputs: T.Buffer[(1, 16, 16, 32), "uint8"], + weight: T.Buffer[(3, 3, 32, 32), "uint8"], + conv2d_nhwc: T.Buffer[(1, 16, 16, 32), "int32"], + ): + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + PadInput = T.alloc_buffer([1, 18, 18, 32], dtype="uint8") + PadInput_global_vtcm = T.alloc_buffer([1, 18, 18, 32], dtype="uint8", scope="global.vtcm") + weight_global_vtcm = T.alloc_buffer([3, 3, 32, 32], dtype="uint8", scope="global.vtcm") + for i0, i1, i2, i3 in T.grid(1, 18, 18, 32): + with T.block("PadInput"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) + PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else( + 1 <= v_i1 and v_i1 < 17 and 1 <= v_i2 and v_i2 < 17, + inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3], + T.uint8(0), + dtype="uint8", + ) + for n_0, h_0, w_0, co_0_0, rh_0, rw_0, rc_0_0 in T.grid(1, 4, 4, 1, 1, 1, 2): + for ax0_ax1_ax2_ax3_fused in T.serial(576): + with T.block("PadInput_global.vtcm"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(18, h_0 * 4 + ax0_ax1_ax2_ax3_fused // 96) + v2 = T.axis.spatial(18, w_0 * 4 + ax0_ax1_ax2_ax3_fused % 96 // 16) + v3 = T.axis.spatial(32, rc_0_0 * 16 + ax0_ax1_ax2_ax3_fused % 16) + T.reads(PadInput[v0, v1, v2, v3]) + T.writes(PadInput_global_vtcm[v0, v1, v2, v3]) + PadInput_global_vtcm[v0, v1, v2, v3] = PadInput[v0, v1, v2, v3] + for ax0_ax1_ax2_ax3_fused in T.serial(4608): + with T.block("weight_global.vtcm"): + v0 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused // 1536) + v1 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 1536 // 512) + v2 = T.axis.spatial(32, rc_0_0 * 16 + ax0_ax1_ax2_ax3_fused % 512 // 32) + v3 = T.axis.spatial(32, ax0_ax1_ax2_ax3_fused % 32) + T.reads(weight[v0, v1, v2, v3]) + T.writes(weight_global_vtcm[v0, v1, v2, v3]) + weight_global_vtcm[v0, v1, v2, v3] = weight[v0, v1, v2, v3] + for n_1, h_1, w_1, co_0_1, rh_1, rw_1, rc_0_1, n_2, h_2, w_2, co_0_2 in T.grid( + 1, 1, 4, 1, 3, 3, 4, 1, 4, 1, 1 + ): + with T.block("conv2d_nhwc_o"): + v_n = T.axis.spatial(1, n_1 + n_2 + n_0) + v_h = T.axis.spatial(16, h_0 * 4 + h_1 * 4 + h_2) + v_w = T.axis.spatial(16, w_2 + w_0 * 4 + w_1) + v_co_o = T.axis.spatial(1, co_0_1 + co_0_2 + co_0_0) + v_rh = T.axis.reduce(3, rh_0 * 3 + rh_1) + v_rw = T.axis.reduce(3, rw_0 * 3 + rw_1) + v_rc_o = T.axis.reduce(8, rc_0_0 * 4 + rc_0_1) + T.reads( + PadInput_global_vtcm[ + v_n, v_h + v_rh, v_w + v_rw, v_rc_o * 4 : v_rc_o * 4 + 4 + ], + weight_global_vtcm[v_rh, v_rw, v_rc_o * 4 : v_rc_o * 4 + 4, 0:32], + ) + T.writes(conv2d_nhwc[v_n, v_h, v_w, 0:32]) + T.block_attr({"meta_schedule.auto_tensorize": "dot_32x4_u8u8i32_vtcm_vrmpy"}) + with T.init(): + for co_1 in T.serial(32): + with T.block("conv2d_nhwc_init"): + v_co_i_init = T.axis.spatial(32, co_1) + T.reads() + T.writes(conv2d_nhwc[v_n, v_h, v_w, v_co_i_init]) + conv2d_nhwc[v_n, v_h, v_w, v_co_i_init] = 0 + for co_1, rc_1 in T.grid(32, 4): + with T.block("conv2d_nhwc"): + v_co_i, v_rc_i = T.axis.remap("SR", [co_1, rc_1]) + T.reads( + conv2d_nhwc[v_n, v_h, v_w, v_co_i], + PadInput_global_vtcm[ + v_n, v_h + v_rh, v_w + v_rw, v_rc_o * 4 + v_rc_i + ], + weight_global_vtcm[v_rh, v_rw, v_rc_o * 4 + v_rc_i, v_co_i], + ) + T.writes(conv2d_nhwc[v_n, v_h, v_w, v_co_i]) + T.block_attr({"meta_schedule.tiling_structure": "SRSRS"}) + 
conv2d_nhwc[v_n, v_h, v_w, v_co_i] = conv2d_nhwc[ + v_n, v_h, v_w, v_co_i + ] + T.Cast( + "int32", + PadInput_global_vtcm[ + v_n, v_h + v_rh, v_w + v_rw, v_rc_o * 4 + v_rc_i + ], + ) * T.Cast( + "int32", weight_global_vtcm[v_rh, v_rw, v_rc_o * 4 + v_rc_i, v_co_i] + ) + + # fmt: on + decision_0 = [ + ("SamplePerfectTile", [4, 1, 1]), + ("SamplePerfectTile", [2, 1, 4]), + ("SamplePerfectTile", [2, 4, 1]), + ("SamplePerfectTile", [2, 2, 2]), + ("SamplePerfectTile", [2, 4]), + ("SamplePerfectTile", [2, 4]), + ("SamplePerfectTile", [2, 4]), + ] + + mod = te.create_prim_func( + te_workload.conv2d_nhwc( + N=1, + H=16, + W=16, + CI=32, + CO=32, + kernel_size=3, + stride=1, + padding=1, + in_dtype="uint8", + out_dtype="int32", + ) + ) + + actual_design_space = generate_design_space( + kind="hexagon", + mod=mod, + target=tvm.target.Target("hexagon"), + types=None, + sch_rules=[ + multi_level_tiling_hexagon(), + ] + + get_rules(kind="hexagon", types=ms.schedule_rule.AutoInline), + ) + + check_sketches( + mod, + sketches=actual_design_space, + expected_mods=[main], + expected_decisions=[decision_0], + ) + + +def conv2d_NCHWc_int8(I, O, H, W, kH, kW, stride, padding, dilation, out_dtype="int32", n_elems=32): + from tvm.topi.utils import get_const_tuple + from tvm.topi.nn.utils import get_pad_tuple + from tvm.topi.nn.pad import pad + + ic_bn = 32 + oc_bn = 32 + n_elems = 4 + dtype = "uint8" + + data = te.placeholder((1, I // ic_bn, H, W, ic_bn), name="data", dtype=dtype) + kernel = te.placeholder( + (O // oc_bn, I // ic_bn, kH, kW, ic_bn // n_elems, oc_bn, n_elems), dtype=dtype + ) + + # layout and out_layout are not used here, + # we keep them for debug convenience when dumping autotvm workload + HSTR, WSTR = stride if isinstance(stride, (tuple, list)) else (stride, stride) + dilation_h, dilation_w = ( + dilation if isinstance(dilation, (tuple, list)) else (dilation, dilation) + ) + + n, ic_chunk, ih, iw, ic_bn = get_const_tuple(data.shape) + in_channel = ic_chunk * ic_bn + oc_chunk, ic_chunk_group, kernel_height, kernel_width, _, oc_bn, _ = get_const_tuple( + kernel.shape + ) + + dilated_kernel_h = (kernel_height - 1) * dilation_h + 1 + dilated_kernel_w = (kernel_width - 1) * dilation_w + 1 + + pad_top, pad_left, pad_down, pad_right = get_pad_tuple( + padding, (dilated_kernel_h, dilated_kernel_w) + ) + HPAD = pad_top + pad_down + WPAD = pad_left + pad_right + + # output shape + out_height = (ih + HPAD - dilated_kernel_h) // HSTR + 1 + out_width = (iw + WPAD - dilated_kernel_w) // WSTR + 1 + oshape = (n, oc_chunk, out_height, out_width, oc_bn) + pad_before = (0, 0, pad_top, pad_left, 0) + pad_after = (0, 0, pad_down, pad_right, 0) + + # DOPAD + DOPAD = HPAD != 0 or WPAD != 0 + if DOPAD: + data_pad = pad(data, pad_before, pad_after, name="data_pad") + else: + data_pad = data + + kh = te.reduce_axis((0, kernel_height), name="kh") + kw = te.reduce_axis((0, kernel_width), name="kw") + + ic_outer = te.reduce_axis((0, in_channel // ic_bn), name="ic_outer") + ic_f_inner = te.reduce_axis((0, ic_bn // n_elems), name="ic_f_inner") + ic_s_inner = te.reduce_axis((0, n_elems), name="ic_s_inner") + + out = te.compute( + oshape, + lambda n, oc_chunk, oh, ow, oc_block: te.sum( + data_pad[ + n, + ic_outer, + oh * HSTR + kh * dilation_h, + ow * WSTR + kw * dilation_w, + ic_f_inner * n_elems + ic_s_inner, + ].astype(out_dtype) + * kernel[oc_chunk, ic_outer, kh, kw, ic_f_inner, oc_block, ic_s_inner].astype( + out_dtype + ), + axis=[kh, kw, ic_outer, ic_f_inner, ic_s_inner], + ), + ) + + return [data, kernel, out] + + +def 
test_conv2d_with_pipeline(): + + # from tvm.script import tir as T + @T.prim_func + def main( + data: T.Buffer[(1, 2, 56, 56, 32), "uint8"], + placeholder: T.Buffer[(2, 2, 3, 3, 8, 32, 4), "uint8"], + compute: T.Buffer[(1, 2, 56, 56, 32), "int32"], + ): + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + with T.block("root"): + T.reads() + T.writes() + T.block_attr( + { + "meta_schedule.parallel": 160, + "meta_schedule.unroll_explicit": 0, + "meta_schedule.vectorize": 32, + } + ) + data_pad = T.alloc_buffer([1, 2, 58, 58, 32], dtype="uint8") + data_pad_global_vtcm = T.alloc_buffer( + [1, 2, 58, 58, 32], dtype="uint8", scope="global.vtcm" + ) + placeholder_global_vtcm = T.alloc_buffer( + [2, 2, 3, 3, 8, 32, 4], dtype="uint8", scope="global.vtcm" + ) + for i0, i1, i2, i3, i4 in T.grid(1, 2, 58, 58, 32): + with T.block("data_pad"): + v_i0, v_i1, v_i2, v_i3, v_i4 = T.axis.remap("SSSSS", [i0, i1, i2, i3, i4]) + T.reads(data[v_i0, v_i1, v_i2 - 1, v_i3 - 1, v_i4]) + T.writes(data_pad[v_i0, v_i1, v_i2, v_i3, v_i4]) + data_pad[v_i0, v_i1, v_i2, v_i3, v_i4] = T.if_then_else( + 1 <= v_i2 and v_i2 < 57 and 1 <= v_i3 and v_i3 < 57, + data[v_i0, v_i1, v_i2 - 1, v_i3 - 1, v_i4], + T.uint8(0), + dtype="uint8", + ) + for n_0, oc_chunk_0, oh_0, ow_0, oc_block_0_0 in T.grid(1, 1, 14, 14, 1): + for kh_0_kw_0_ic_outer_0_ic_f_inner_0_ic_s_inner_0_0_fused in T.serial( + 2, + annotations={ + "software_pipeline_async_stages": [0], + "software_pipeline_order": [0, 1, 2], + "software_pipeline_stage": [0, 0, 1], + }, + ): + for ax0_ax1_ax2_ax3_ax4_fused in T.serial(1152): + with T.block("data_pad_global.vtcm"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(2, ax0_ax1_ax2_ax3_ax4_fused // 576) + v2 = T.axis.spatial( + 58, oh_0 * 4 + ax0_ax1_ax2_ax3_ax4_fused % 576 // 96 + ) + v3 = T.axis.spatial(58, ow_0 * 4 + ax0_ax1_ax2_ax3_ax4_fused % 96 // 16) + v4 = T.axis.spatial( + 32, + kh_0_kw_0_ic_outer_0_ic_f_inner_0_ic_s_inner_0_0_fused * 16 + + ax0_ax1_ax2_ax3_ax4_fused % 16, + ) + T.reads(data_pad[v0, v1, v2, v3, v4]) + T.writes(data_pad_global_vtcm[v0, v1, v2, v3, v4]) + data_pad_global_vtcm[v0, v1, v2, v3, v4] = data_pad[v0, v1, v2, v3, v4] + for ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused in T.serial(18432): + with T.block("placeholder_global.vtcm"): + v0 = T.axis.spatial(2, ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused // 9216) + v1 = T.axis.spatial(2, ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % 9216 // 4608) + v2 = T.axis.spatial(3, ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % 4608 // 1536) + v3 = T.axis.spatial(3, ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % 1536 // 512) + v4 = T.axis.spatial( + 8, + kh_0_kw_0_ic_outer_0_ic_f_inner_0_ic_s_inner_0_0_fused * 4 + + ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % 512 // 128, + ) + v5 = T.axis.spatial(32, ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % 128 // 4) + v6 = T.axis.spatial(4, ax0_ax1_ax2_ax3_ax4_ax5_ax6_fused % 4) + T.reads(placeholder[v0, v1, v2, v3, v4, v5, v6]) + T.writes(placeholder_global_vtcm[v0, v1, v2, v3, v4, v5, v6]) + placeholder_global_vtcm[v0, v1, v2, v3, v4, v5, v6] = placeholder[ + v0, v1, v2, v3, v4, v5, v6 + ] + for ( + n_1, + oc_chunk_1, + oh_1, + ow_1, + oc_block_0_1, + kh_1, + kw_1, + ic_outer_1, + ic_f_inner_1, + ic_s_inner_0_1, + n_2, + oc_chunk_2, + oh_2, + ow_2, + oc_block_0_2, + ) in T.grid(1, 1, 4, 2, 1, 3, 3, 2, 4, 1, 1, 2, 1, 2, 1): + with T.block("compute_o"): + v_n = T.axis.spatial(1, n_2 + n_0 + n_1) + v_oc_chunk = T.axis.spatial( + 2, oc_chunk_0 * 2 + oc_chunk_1 * 2 + oc_chunk_2 + ) + v_oh = T.axis.spatial(56, oh_2 + oh_0 * 4 + oh_1) + v_ow = 
T.axis.spatial(56, ow_0 * 4 + ow_1 * 2 + ow_2) + v_oc_block_o = T.axis.spatial( + 1, oc_block_0_1 + oc_block_0_2 + oc_block_0_0 + ) + v_kh, v_kw, v_ic_outer = T.axis.remap("RRR", [kh_1, kw_1, ic_outer_1]) + v_ic_f_inner = T.axis.reduce( + 8, + kh_0_kw_0_ic_outer_0_ic_f_inner_0_ic_s_inner_0_0_fused * 4 + + ic_f_inner_1, + ) + v_ic_s_inner_o = T.axis.reduce(1, ic_s_inner_0_1) + T.reads( + data_pad_global_vtcm[ + v_n, + v_ic_outer, + v_oh + v_kh, + v_ow + v_kw, + v_ic_f_inner * 4 : v_ic_f_inner * 4 + 4, + ], + placeholder_global_vtcm[ + v_oc_chunk, v_ic_outer, v_kh, v_kw, v_ic_f_inner, 0:32, 0:4 + ], + ) + T.writes(compute[v_n, v_oc_chunk, v_oh, v_ow, 0:32]) + T.block_attr( + {"meta_schedule.auto_tensorize": "dot_32x4_u8u8i32_vtcm_vrmpy"} + ) + with T.init(): + for oc_block_1 in T.serial(32): + with T.block("compute_init"): + v_oc_block_i_init = T.axis.spatial(32, oc_block_1) + T.reads() + T.writes( + compute[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i_init] + ) + compute[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i_init] = 0 + for oc_block_1, ic_s_inner_1 in T.grid(32, 4): + with T.block("compute"): + v_oc_block_i, v_ic_s_inner_i = T.axis.remap( + "SR", [oc_block_1, ic_s_inner_1] + ) + T.reads( + compute[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i], + data_pad_global_vtcm[ + v_n, + v_ic_outer, + v_oh + v_kh, + v_ow + v_kw, + v_ic_f_inner * 4 + v_ic_s_inner_i, + ], + placeholder_global_vtcm[ + v_oc_chunk, + v_ic_outer, + v_kh, + v_kw, + v_ic_f_inner, + v_oc_block_i, + v_ic_s_inner_i, + ], + ) + T.writes(compute[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i]) + T.block_attr({"meta_schedule.tiling_structure": "SRSRS"}) + compute[v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i] = compute[ + v_n, v_oc_chunk, v_oh, v_ow, v_oc_block_i + ] + T.Cast( + "int32", + data_pad_global_vtcm[ + v_n, + v_ic_outer, + v_oh + v_kh, + v_ow + v_kw, + v_ic_f_inner * 4 + v_ic_s_inner_i, + ], + ) * T.Cast( + "int32", + placeholder_global_vtcm[ + v_oc_chunk, + v_ic_outer, + v_kh, + v_kw, + v_ic_f_inner, + v_oc_block_i, + v_ic_s_inner_i, + ], + ) + + # fmt: on + decision_0 = [ + ("SamplePerfectTile", [4, 1, 1]), + ("SamplePerfectTile", [2, 1, 4]), + ("SamplePerfectTile", [2, 4, 1]), + ("SamplePerfectTile", [2, 2, 2]), + ("SamplePerfectTile", [2, 4, 1]), + ("SamplePerfectTile", [2, 4]), + ("SamplePerfectTile", [2, 4]), + ("SamplePerfectTile", [2, 4]), + ("SamplePerfectTile", [2, 4]), + ("SamplePerfectTile", [2, 4]), + ("SampleCategorical", 0), + ] + + strides = (1, 1) + padding = (1, 1) + dilation = (1, 1) + + mod = te.create_prim_func( + conv2d_NCHWc_int8(64, 64, 56, 56, 3, 3, strides, padding, dilation, out_dtype="int32") + ) + + actual_design_space = generate_design_space( + kind="hexagon", + mod=mod, + target=tvm.target.Target("hexagon -num-cores=10"), + types=None, + sch_rules=[ + multi_level_tiling_hexagon(use_software_pipeline=True), + ms.schedule_rule.ParallelizeVectorizeUnroll( + max_jobs_per_core=16, + max_vectorize_extent=32, + unroll_max_steps=[0, 16, 64, 512], + unroll_explicit=True, + ), + ] + + get_rules(kind="hexagon", types=ms.schedule_rule.AutoInline), + ) + + check_sketches( + mod, + sketches=actual_design_space, + expected_mods=[main], + expected_decisions=[decision_0], + ) + + +def test_conv_1x1(): + # fmt: off + @T.prim_func + def main(inputs: T.Buffer[(1, 16, 16, 32), "uint8"], weight: T.Buffer[(3, 3, 32, 32), "uint8"], conv2d_nhwc: T.Buffer[(1, 16, 16, 32), "int32"]): + # function attr dict + T.func_attr({"global_symbol": "main", "tir.noalias": True}) + # body + # with T.block("root") + PadInput = 
T.alloc_buffer([1, 18, 18, 32], dtype="uint8") + PadInput_global_vtcm = T.alloc_buffer([1, 18, 18, 32], dtype="uint8", scope="global.vtcm") + weight_global_vtcm = T.alloc_buffer([3, 3, 32, 32], dtype="uint8", scope="global.vtcm") + for i0, i1, i2, i3 in T.grid(1, 18, 18, 32): + with T.block("PadInput"): + v_i0, v_i1, v_i2, v_i3 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.reads(inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3]) + T.writes(PadInput[v_i0, v_i1, v_i2, v_i3]) + PadInput[v_i0, v_i1, v_i2, v_i3] = T.if_then_else(1 <= v_i1 and v_i1 < 17 and 1 <= v_i2 and v_i2 < 17, inputs[v_i0, v_i1 - 1, v_i2 - 1, v_i3], T.uint8(0), dtype="uint8") + for n_0, h_0, w_0, co_0_0, rh_0, rw_0, rc_0_0 in T.grid(1, 4, 4, 1, 1, 1, 2): + for ax0_ax1_ax2_ax3_fused in T.serial(576): + with T.block("PadInput_global.vtcm"): + v0 = T.axis.spatial(1, 0) + v1 = T.axis.spatial(18, h_0 * 4 + ax0_ax1_ax2_ax3_fused // 96) + v2 = T.axis.spatial(18, w_0 * 4 + ax0_ax1_ax2_ax3_fused % 96 // 16) + v3 = T.axis.spatial(32, rc_0_0 * 16 + ax0_ax1_ax2_ax3_fused % 16) + T.reads(PadInput[v0, v1, v2, v3]) + T.writes(PadInput_global_vtcm[v0, v1, v2, v3]) + PadInput_global_vtcm[v0, v1, v2, v3] = PadInput[v0, v1, v2, v3] + for ax0_ax1_ax2_ax3_fused in T.serial(4608): + with T.block("weight_global.vtcm"): + v0 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused // 1536) + v1 = T.axis.spatial(3, ax0_ax1_ax2_ax3_fused % 1536 // 512) + v2 = T.axis.spatial(32, rc_0_0 * 16 + ax0_ax1_ax2_ax3_fused % 512 // 32) + v3 = T.axis.spatial(32, ax0_ax1_ax2_ax3_fused % 32) + T.reads(weight[v0, v1, v2, v3]) + T.writes(weight_global_vtcm[v0, v1, v2, v3]) + weight_global_vtcm[v0, v1, v2, v3] = weight[v0, v1, v2, v3] + for n_1, h_1, w_1, co_0_1, rh_1, rw_1, rc_0_1, n_2, h_2, w_2, co_0_2 in T.grid(1, 1, 4, 1, 3, 3, 4, 1, 4, 1, 1): + with T.block("conv2d_nhwc_o"): + v_n = T.axis.spatial(1, n_1 + n_2 + n_0) + v_h = T.axis.spatial(16, h_0 * 4 + h_1 * 4 + h_2) + v_w = T.axis.spatial(16, w_2 + w_0 * 4 + w_1) + v_co_o = T.axis.spatial(1, co_0_1 + co_0_2 + co_0_0) + v_rh = T.axis.reduce(3, rh_0 * 3 + rh_1) + v_rw = T.axis.reduce(3, rw_0 * 3 + rw_1) + v_rc_o = T.axis.reduce(8, rc_0_0 * 4 + rc_0_1) + T.reads(PadInput_global_vtcm[v_n, v_h + v_rh, v_w + v_rw, v_rc_o * 4 : v_rc_o * 4 + 4], weight_global_vtcm[v_rh, v_rw, v_rc_o * 4 : v_rc_o * 4 + 4, 0 : 32]) + T.writes(conv2d_nhwc[v_n, v_h, v_w, 0 : 32]) + T.block_attr({"meta_schedule.auto_tensorize":"dot_32x4_u8u8i32_vtcm_vrmpy"}) + with T.init(): + for co_1 in T.serial(32): + with T.block("conv2d_nhwc_init"): + v_co_i_init = T.axis.spatial(32, co_1) + T.reads() + T.writes(conv2d_nhwc[v_n, v_h, v_w, v_co_i_init]) + conv2d_nhwc[v_n, v_h, v_w, v_co_i_init] = 0 + for co_1, rc_1 in T.grid(32, 4): + with T.block("conv2d_nhwc"): + v_co_i, v_rc_i = T.axis.remap("SR", [co_1, rc_1]) + T.reads(conv2d_nhwc[v_n, v_h, v_w, v_co_i], PadInput_global_vtcm[v_n, v_h + v_rh, v_w + v_rw, v_rc_o * 4 + v_rc_i], weight_global_vtcm[v_rh, v_rw, v_rc_o * 4 + v_rc_i, v_co_i]) + T.writes(conv2d_nhwc[v_n, v_h, v_w, v_co_i]) + T.block_attr({"meta_schedule.tiling_structure":"SRSRS"}) + conv2d_nhwc[v_n, v_h, v_w, v_co_i] = conv2d_nhwc[v_n, v_h, v_w, v_co_i] + T.Cast("int32", PadInput_global_vtcm[v_n, v_h + v_rh, v_w + v_rw, v_rc_o * 4 + v_rc_i]) * T.Cast("int32", weight_global_vtcm[v_rh, v_rw, v_rc_o * 4 + v_rc_i, v_co_i]) + # fmt: on + decision_0 = [ + ("SamplePerfectTile", [4, 1, 1]), + ("SamplePerfectTile", [2, 1, 4]), + ("SamplePerfectTile", [2, 4, 1]), + ("SamplePerfectTile", [2, 2, 2]), + ("SamplePerfectTile", [2, 4]), + ("SamplePerfectTile", [2, 4]), + 
("SamplePerfectTile", [2, 4]), + ] + + mod = te.create_prim_func( + te_workload.conv2d_nhwc( + N=1, + H=16, + W=16, + CI=32, + CO=32, + kernel_size=3, + stride=1, + padding=1, + in_dtype="uint8", + out_dtype="int32", + ) + ) + + actual_design_space = generate_design_space( + kind="hexagon", + mod=mod, + target=tvm.target.Target("hexagon"), + types=None, + sch_rules=[ + multi_level_tiling_hexagon(), + ] + + get_rules(kind="hexagon", types=ms.schedule_rule.AutoInline), + ) + + check_sketches( + mod, + sketches=actual_design_space, + expected_mods=[main], + expected_decisions=[decision_0], + ) + + +def test_matmul_relu_non_tensorizable(): + # expected to do nothing on non-tensorizable workloads + mod = te.create_prim_func( + te_workload.matmul_relu( # dtype doesn't match tensor intrin + n=128, + m=128, + k=128, + ) + ) + (sch,) = generate_design_space( + kind="hexagon", + mod=mod, + target=tvm.target.Target("hexagon"), + types=None, + sch_rules=[multi_level_tiling_hexagon(write_reuse_scope="global")] + + get_rules("hexagon", ms.schedule_rule.AutoInline), + ) + tvm.ir.assert_structural_equal(mod, sch.mod["main"]) + + +if __name__ == "__main__": + test_dense_base() + test_dense_with_fallback() + test_dense_global() + test_dense_with_pipeline() + test_padded_dense() + test_conv2d() + test_conv_1x1() + test_conv2d_with_pipeline() + test_matmul_relu_non_tensorizable() From 375555d45a6fa27e35f0bb830cfa9ea269dde81e Mon Sep 17 00:00:00 2001 From: Noah Verke Date: Wed, 18 Jan 2023 14:51:16 -0800 Subject: [PATCH 2/4] Remove unnecessary statement visitor. --- include/tvm/tir/data_type_rewriter.h | 1 - src/tir/ir/data_type_rewriter.cc | 17 ----------------- 2 files changed, 18 deletions(-) diff --git a/include/tvm/tir/data_type_rewriter.h b/include/tvm/tir/data_type_rewriter.h index 9cd6eae02028..5f72f75ede41 100644 --- a/include/tvm/tir/data_type_rewriter.h +++ b/include/tvm/tir/data_type_rewriter.h @@ -108,7 +108,6 @@ class IndexDataTypeRewriter : public DataTypeLegalizer { Stmt VisitStmt_(const IfThenElseNode* op) override; Stmt VisitStmt_(const DeclBufferNode* op) override; Stmt VisitStmt_(const AllocateNode* op) override; - Stmt VisitStmt_(const AllocateConstNode* op) override; PrimExpr VisitExpr_(const EQNode* op) override; PrimExpr VisitExpr_(const NENode* op) override; PrimExpr VisitExpr_(const LTNode* op) override; diff --git a/src/tir/ir/data_type_rewriter.cc b/src/tir/ir/data_type_rewriter.cc index ca589957294b..c17a52e403f7 100644 --- a/src/tir/ir/data_type_rewriter.cc +++ b/src/tir/ir/data_type_rewriter.cc @@ -244,23 +244,6 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const AllocateNode* op) { } } -Stmt IndexDataTypeRewriter::VisitStmt_(const AllocateConstNode* op) { - bool is_enabled = is_enabled_; - is_enabled_ = true; - auto new_extents = op->extents.Map([this](const PrimExpr& e) { return this->VisitExpr(e); }); - is_enabled_ = is_enabled; - auto new_body = this->VisitStmt(op->body); - if (!new_extents.same_as(op->extents) || !new_body.same_as(op->body)) { - AllocateConst new_allocate = GetRef(op); - auto* n = new_allocate.CopyOnWrite(); - n->extents = std::move(new_extents); - n->body = std::move(new_body); - return std::move(new_allocate); - } else { - return GetRef(op); - } -} - Stmt IndexDataTypeRewriter::VisitStmt_(const DeclBufferNode* op) { Buffer new_buffer = VisitBuffer(op->buffer); DeclBuffer decl_buffer = Downcast(StmtExprMutator::VisitStmt_(op)); From 81433212a346b2143d4cc06ee6b9335d31aa7375 Mon Sep 17 00:00:00 2001 From: Noah Verke Date: Fri, 17 Feb 2023 10:50:38 -0800 Subject: 
[PATCH 3/4] Address PR review comments. --- .../multi_level_tiling_hexagon.cc | 6 +++--- .../contrib/test_hexagon/test_conv2d_async.py | 20 +++++++++---------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_hexagon.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_hexagon.cc index c8636dc5b429..690a83386b1a 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_hexagon.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_hexagon.cc @@ -60,7 +60,7 @@ std::vector<State> MultiLevelTilingHexagonNode::AddSoftwarePipeline(State state) if (!use_software_pipeline) { return {state}; } - // The current config is not suitable for software pipelining. + // The current config is not suitable for software pipelining if there are fewer than two reduction indices (r_indices_). if (r_indices_.size() < 2) { return {state}; } @@ -91,9 +91,9 @@ std::vector<State> MultiLevelTilingHexagonNode::AddSoftwarePipeline(State state) Array<Integer> software_pipeline_order; Array<Integer> software_pipeline_async_stages; if (cache_read_count == 2) { - software_pipeline_stage = Array<Integer>{0, 0, 1}; + software_pipeline_stage = Array<Integer>{0, 0, 1}; // The two cache reads are merged into a single pipeline stage. software_pipeline_order = Array<Integer>{0, 1, 2}; - software_pipeline_async_stages = Array<Integer>{0}; + software_pipeline_async_stages = Array<Integer>{0}; // Stage 0 is executed asynchronously. } else { software_pipeline_stage = Array<Integer>{0, 1}; software_pipeline_order = Array<Integer>{0, 1}; diff --git a/tests/python/contrib/test_hexagon/test_conv2d_async.py b/tests/python/contrib/test_hexagon/test_conv2d_async.py index 5e92a49f8613..10376c35eabe 100644 --- a/tests/python/contrib/test_hexagon/test_conv2d_async.py +++ b/tests/python/contrib/test_hexagon/test_conv2d_async.py @@ -85,19 +85,17 @@ def tune_vrmpy_auto_tensorize(mod, params, hexagon_launcher): executor = relay.backend.Executor("graph", {"link-params": True}) mod = mod.with_attr("executor", executor) - use_async = True - if use_async: - config = { - "tir.use_async_copy": True, - "tir.merge_async_commit_queue_scope": False, - } + config = { + "tir.use_async_copy": True, + "tir.merge_async_commit_queue_scope": False, + } - ctx = tvm.transform.PassContext( - opt_level=3, - config=config, - ) - sch_rules = sch_rules_async + ctx = tvm.transform.PassContext( + opt_level=3, + config=config, + ) + sch_rules = sch_rules_async with tempfile.TemporaryDirectory() as work_dir: database = ms.relay_integration.tune_relay( From 7abbd6d95b04f7796bdbf14be7b0840075ad3606 Mon Sep 17 00:00:00 2001 From: Noah Verke Date: Fri, 17 Feb 2023 12:30:40 -0800 Subject: [PATCH 4/4] lint --- .../schedule_rule/multi_level_tiling_hexagon.cc | 8 +++++--- tests/python/contrib/test_hexagon/test_conv2d_async.py | 1 - 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling_hexagon.cc b/src/meta_schedule/schedule_rule/multi_level_tiling_hexagon.cc index 690a83386b1a..f47a4b151239 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling_hexagon.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling_hexagon.cc @@ -60,7 +60,8 @@ std::vector<State> MultiLevelTilingHexagonNode::AddSoftwarePipeline(State state) if (!use_software_pipeline) { return {state}; } - // The current config is not suitable for software pipelining if there are fewer than two reduction indices (r_indices_). + // The current config is not suitable for software pipelining if there are fewer than two + // reduction indices (r_indices_). 
if (r_indices_.size() < 2) { return {state}; } @@ -91,9 +92,10 @@ std::vector<State> MultiLevelTilingHexagonNode::AddSoftwarePipeline(State state) Array<Integer> software_pipeline_order; Array<Integer> software_pipeline_async_stages; if (cache_read_count == 2) { - software_pipeline_stage = Array<Integer>{0, 0, 1}; // The two cache reads are merged into a single pipeline stage. + software_pipeline_stage = + Array<Integer>{0, 0, 1}; // The two cache reads are merged into a single pipeline stage. software_pipeline_order = Array<Integer>{0, 1, 2}; - software_pipeline_async_stages = Array<Integer>{0}; // Stage 0 is executed asynchronously. + software_pipeline_async_stages = Array<Integer>{0}; // Stage 0 is executed asynchronously. } else { software_pipeline_stage = Array<Integer>{0, 1}; software_pipeline_order = Array<Integer>{0, 1}; diff --git a/tests/python/contrib/test_hexagon/test_conv2d_async.py b/tests/python/contrib/test_hexagon/test_conv2d_async.py index 10376c35eabe..4210dfe8e7e4 100644 --- a/tests/python/contrib/test_hexagon/test_conv2d_async.py +++ b/tests/python/contrib/test_hexagon/test_conv2d_async.py @@ -85,7 +85,6 @@ def tune_vrmpy_auto_tensorize(mod, params, hexagon_launcher): executor = relay.backend.Executor("graph", {"link-params": True}) mod = mod.with_attr("executor", executor) - config = { "tir.use_async_copy": True, "tir.merge_async_commit_queue_scope": False,
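
For reference, the three software_pipeline_* values that AddSoftwarePipeline emits are ordinary TIR schedule annotations, the same ones that appear on the pipelined loop in the expected modules above. The following is a minimal sketch, not part of the patch itself; the schedule and loop variables are hypothetical placeholders:

    from tvm import tir

    def annotate_dma_pipeline(sch: tir.Schedule, loop: tir.schedule.LoopRV) -> None:
        # Three blocks sit under `loop`: two cache reads into global.vtcm and the
        # tensorized compute. The stage array places both cache reads in stage 0 and
        # the compute in stage 1, so iteration i+1's copies overlap iteration i's compute.
        sch.annotate(loop, "software_pipeline_stage", [0, 0, 1])
        # The blocks keep their textual order inside the pipelined loop body.
        sch.annotate(loop, "software_pipeline_order", [0, 1, 2])
        # Stage 0 is lowered to asynchronous DMA copies when "tir.use_async_copy"
        # is set in the PassContext, as tune_vrmpy_auto_tensorize does above.
        sch.annotate(loop, "software_pipeline_async_stages", [0])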
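Similarly, each ("SamplePerfectTile", [...]) entry in the expected_decisions lists records the factors chosen by one sample_perfect_tile call in the traced schedule. A hedged illustration of how one such decision replays when the trace is applied, assuming a hypothetical loop variable `h` of extent 8:

    # Replaying the recorded decision ("SamplePerfectTile", [2, 1, 4]):
    factors = sch.sample_perfect_tile(loop=h, n=3, max_innermost_factor=64, decision=[2, 1, 4])
    # The sampled factors drive the corresponding split, yielding loops of
    # extents 2, 1, and 4 whose product equals the original extent.
    h_0, h_1, h_2 = sch.split(h, factors=factors)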