[MetaSchedule][Hexagon] Add MultiLevelTilingHexagon to schedule asyn… #13721
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -210,6 +210,31 @@ class ScheduleRule : public runtime::ObjectRef { | |
| Optional<Array<Integer>> vector_load_lens, Optional<Map<String, ObjectRef>> reuse_read, | ||
| Optional<Map<String, ObjectRef>> reuse_write, bool use_software_pipeline); | ||
|
|
||
| /*! | ||
| * \brief Extension of MultiLevelTiling for auto-tensorization with multiple groups of candidate | ||
| * Hexagon tensor intrinsics | ||
| * \param intrin_groups A list of groups of candidate tensor intrinsics. Each map must contain | ||
| * the key "compute", whose value names the tensor intrin used for computation; all named | ||
| * intrinsics must be registered via TensorIntrin.register(...) beforehand | ||
| * \param structure The tiling structure. Recommended: | ||
| * - 'SRSRS' on Hexagon | ||
| * \param tile_binds For each level of tiles, which thread axis it is bound to. Not supported | ||
| * on Hexagon. | ||
| * \param max_innermost_factor The maximum size of the innermost factor. NullOpt means no limit | ||
| * \param vector_load_lens The length of vector lane in vectorized cooperative fetching. | ||
|
Contributor: Confused as to why we have both the vector load length and the max innermost factor. These seem redundant. Aren't we always going to vectorize over the innermost loop? And, if vectorization is enabled, won't we use …

Contributor (author): These have two separate uses: vector_load_lens is used for vector loads outside of the tiling loops, whereas max_innermost_factor is for loop tiling. As for the naming and usage I am not 100% sure.

Contributor: Like I said, just an observation given you are inheriting this API.
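To make the distinction in the thread above concrete, here is a small pure-Python sketch (the helper names are hypothetical, not TVM code) of the two separate decisions these knobs control: `max_innermost_factor` caps the tile size chosen for the innermost loop split, while `vector_load_lens` lists the lane lengths tried when vectorizing the cooperative-fetch (cache read) loop outside the tiling loops.

```python
# Hypothetical sketch, not TVM's implementation: the two knobs constrain
# two different sampling decisions made by the schedule rule.

def innermost_tile_candidates(extent, max_innermost_factor):
    """Tile sizes allowed for the innermost split, capped by max_innermost_factor."""
    return [f for f in range(1, extent + 1)
            if extent % f == 0
            and (max_innermost_factor is None or f <= max_innermost_factor)]

def cooperative_fetch_lanes(fetch_extent, vector_load_lens):
    """Vector lane lengths usable for the cache-read (cooperative fetch) loop."""
    if vector_load_lens is None:  # None disables vectorized fetching
        return []
    return [l for l in vector_load_lens if fetch_extent % l == 0]

print(innermost_tile_candidates(64, 4))          # [1, 2, 4]
print(cooperative_fetch_lanes(128, [4, 8, 16]))  # [4, 8, 16]
```

The point of the sketch is that the first function shapes the tiling space while the second only affects how the generated cache-read stage is vectorized, which is why the two parameters coexist.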
||
| * NullOpt means disable vectorization | ||
| * \param reuse_read Data reuse configuration for reading. NullOpt means no reuse. | ||
| * \param reuse_write Data reuse configuration for writing. NullOpt means no reuse. | ||
| * \param use_software_pipeline Whether to use the software pipeline. | ||
| * \return The schedule rule created | ||
| */ | ||
| TVM_DLL static ScheduleRule MultiLevelTilingHexagon( | ||
| Array<Map<String, String>> intrin_groups, String structure, | ||
| Optional<Array<String>> tile_binds, Optional<Integer> max_innermost_factor, | ||
| Optional<Array<Integer>> vector_load_lens, Optional<Map<String, ObjectRef>> reuse_read, | ||
| Optional<Map<String, ObjectRef>> reuse_write, bool use_software_pipeline); | ||
|
|
||
| /*! | ||
| * \brief Extension of MultiLevelTiling for backends with wide vectors. | ||
| * The loop over the innermost spatial axis of the output buffer is always vectorized with the | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -197,6 +197,60 @@ def __init__( | |
| ) | ||
|
|
||
|
|
||
| @register_object("meta_schedule.MultiLevelTilingHexagon") | ||
| class MultiLevelTilingHexagon(ScheduleRule): | ||
| """Extension of MultiLevelTiling for auto-tensorizing with multiple groups of candidate hexagon | ||
| intrinsics. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| intrin_groups : List[Mapping[str, str]] | ||
| A list of groups of candidate tensor intrinsics. Each map must contain the key | ||
| "compute", whose value names the tensor intrin used for computation; all named | ||
| intrinsics must be registered via TensorIntrin.register(...) beforehand | ||
| structure : str | ||
| The tiling structure. Recommended: | ||
| - 'SRSRS' on Hexagon | ||
|
Contributor: You might map SRSRS to the layout NCHWc in the comment.

Contributor (author): see above
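For readers making the same connection as the reviewer, here is a tiny sketch (hypothetical, not TVM's parser) of how a structure string such as `'SRSRS'` is read: `'S'` entries are spatial tiling levels and `'R'` entries are reduction tiling levels, so `'SRSRS'` splits each spatial loop into three tiles and each reduction loop into two.

```python
# Hypothetical sketch of interpreting a tiling structure string.
# 'S' levels split spatial loops, 'R' levels split reduction loops;
# the counts decide how many tiles each loop kind is split into.

def tile_levels(structure):
    s_indices = [i for i, c in enumerate(structure) if c == "S"]
    r_indices = [i for i, c in enumerate(structure) if c == "R"]
    return s_indices, r_indices

s_idx, r_idx = tile_levels("SRSRS")
print(s_idx)  # [0, 2, 4] -> each spatial loop gets 3 tiles
print(r_idx)  # [1, 3]    -> each reduction loop gets 2 tiles
```

In the NCHWc analogy from the thread, the reduction (R) levels are what end up tiling the channel-like axes that the tensor intrinsic reduces over.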
||
| tile_binds : Optional[List[str]] | ||
| For each level of tiles, which thread axis it is bound to. Not supported on Hexagon. | ||
| max_innermost_factor : Optional[int] | ||
| The maximum size of the innermost factor. None means no limit | ||
| vector_load_lens : Optional[List[int]] | ||
| The length of vector lane in vectorized cooperative fetching. | ||
| None means disable vectorization | ||
| reuse_read : Optional[ReuseType] | ||
| Data reuse configuration for reading. None means no reuse. | ||
| reuse_write : Optional[ReuseType] | ||
| Data reuse configuration for writing. None means no reuse. | ||
| use_software_pipeline : bool | ||
| Whether to use the software pipeline. | ||
| """ | ||
|
|
||
| def __init__( | ||
| self, | ||
| intrin_groups: List[Mapping[str, str]], | ||
| structure: str, | ||
| tile_binds: Optional[List[str]] = None, | ||
| max_innermost_factor: Optional[int] = None, | ||
| vector_load_lens: Optional[List[int]] = None, | ||
| reuse_read: Optional[ReuseType] = None, | ||
| reuse_write: Optional[ReuseType] = None, | ||
| use_software_pipeline: bool = False, | ||
| ) -> None: | ||
| self.__init_handle_by_constructor__( | ||
| _ffi_api.ScheduleRuleMultiLevelTilingHexagon, # type: ignore # pylint: disable=no-member | ||
| intrin_groups, | ||
| structure, | ||
| tile_binds, | ||
| max_innermost_factor, | ||
| vector_load_lens, | ||
| reuse_read.as_dict() if reuse_read is not None else None, | ||
| reuse_write.as_dict() if reuse_write is not None else None, | ||
| use_software_pipeline, | ||
| ) | ||
|
|
||
|
|
||
| @register_object("meta_schedule.MultiLevelTilingWideVector") | ||
| class MultiLevelTilingWideVector(ScheduleRule): | ||
| """Extension of MultiLevelTiling for backends with wide vectors. The loop over the innermost | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,6 +16,7 @@ | |
| # under the License. | ||
| # pylint: disable=invalid-name,missing-function-docstring | ||
| """Intrinsics for Hexagon tensorization.""" | ||
|
|
||
| from tvm.script import tir as T | ||
| from .. import TensorIntrin | ||
|
|
||
|
|
@@ -68,12 +69,12 @@ def sync_dma_load_impl(a: T.handle, c: T.handle) -> None: | |
| return sync_dma_load_desc, sync_dma_load_impl | ||
|
|
||
|
|
||
| def generate_dot_product_32x4_u8u8i32(mem_scope="global"): | ||
| def generate_dot_product_32x4_u8u8i32(read_mem_scope="global", write_mem_scope="global"): | ||
|
Contributor: Why are the defaults here "global" instead of "global.vtcm"?

Contributor (author): Because typically these are used without vtcm.
||
| @T.prim_func | ||
| def dot_product_32x4_u8u8i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None: | ||
| A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope) | ||
| B = T.match_buffer(b, (32, 4), "uint8", offset_factor=1, scope=mem_scope) | ||
| C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope) | ||
| A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=read_mem_scope) | ||
| B = T.match_buffer(b, (32, 4), "uint8", offset_factor=1, scope=read_mem_scope) | ||
| C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=write_mem_scope) | ||
| with T.block("root"): | ||
| T.reads(C[0:32], A[0:4], B[0:32, 0:4]) | ||
| T.writes(C[0:32]) | ||
|
|
@@ -85,9 +86,9 @@ def dot_product_32x4_u8u8i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None | |
|
|
||
| @T.prim_func | ||
| def dot_product_32x4_u8u8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> None: | ||
| A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope) | ||
| B = T.match_buffer(b, (32, 4), "uint8", offset_factor=1, scope=mem_scope) | ||
| C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope) | ||
| A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=read_mem_scope) | ||
| B = T.match_buffer(b, (32, 4), "uint8", offset_factor=1, scope=read_mem_scope) | ||
| C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=write_mem_scope) | ||
| with T.block("root"): | ||
| T.reads(C[0:32], A[0:4], B[0:32, 0:4]) | ||
| T.writes(C[0:32]) | ||
|
|
@@ -110,12 +111,12 @@ def dot_product_32x4_u8u8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> Non | |
| return dot_product_32x4_u8u8i32_desc, dot_product_32x4_u8u8i32_vrmpy | ||
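As a plain-Python reference (not the HVX vrmpy instruction itself), this is the computation the 32x4 u8u8i32 intrinsic above describes: 32 output lanes, each accumulating a 4-element dot product of uint8 operands into an int32 accumulator.

```python
# Reference semantics of the 32x4 u8u8 -> i32 dot product that vrmpy
# tensorizes. A is 4 uint8 values, B is 32x4 uint8 values, C holds
# 32 int32 accumulators that are updated in place.

def dot_32x4_u8u8i32(A, B, C):
    for i in range(32):        # one HVX output lane per i
        for k in range(4):     # 4-wide dot product per lane
            C[i] += A[k] * B[i][k]
    return C

A = [1, 2, 3, 4]
B = [[1, 1, 1, 1]] * 32
C = dot_32x4_u8u8i32(A, B, [0] * 32)
print(C[0])  # 10 == 1 + 2 + 3 + 4
```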
|
|
||
|
|
||
| def generate_dot_product_32x4_u8i8i32(mem_scope="global"): | ||
| def generate_dot_product_32x4_u8i8i32(read_mem_scope="global", write_mem_scope="global"): | ||
| @T.prim_func | ||
| def dot_product_32x4_u8i8i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None: | ||
| A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope) | ||
| B = T.match_buffer(b, (32, 4), "int8", offset_factor=1, scope=mem_scope) | ||
| C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope) | ||
| A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=read_mem_scope) | ||
| B = T.match_buffer(b, (32, 4), "int8", offset_factor=1, scope=read_mem_scope) | ||
| C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=write_mem_scope) | ||
| with T.block("root"): | ||
| T.reads(C[0:32], A[0:4], B[0:32, 0:4]) | ||
| T.writes(C[0:32]) | ||
|
|
@@ -127,9 +128,9 @@ def dot_product_32x4_u8i8i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None | |
|
|
||
| @T.prim_func | ||
| def dot_product_32x4_u8i8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> None: | ||
| A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=mem_scope) | ||
| B = T.match_buffer(b, (32, 4), "int8", offset_factor=1, scope=mem_scope) | ||
| C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope) | ||
| A = T.match_buffer(a, (4,), "uint8", offset_factor=1, scope=read_mem_scope) | ||
| B = T.match_buffer(b, (32, 4), "int8", offset_factor=1, scope=read_mem_scope) | ||
| C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=write_mem_scope) | ||
| with T.block("root"): | ||
| T.reads(C[0:32], A[0:4], B[0:32, 0:4]) | ||
| T.writes(C[0:32]) | ||
|
|
@@ -152,12 +153,12 @@ def dot_product_32x4_u8i8i32_vrmpy(a: T.handle, b: T.handle, c: T.handle) -> Non | |
| return dot_product_32x4_u8i8i32_desc, dot_product_32x4_u8i8i32_vrmpy | ||
|
|
||
|
|
||
| def generate_dot_product_32x2_i16i16i32(mem_scope="global"): | ||
| def generate_dot_product_32x2_i16i16i32(read_mem_scope="global", write_mem_scope="global"): | ||
| @T.prim_func | ||
| def dot_product_32x2_i16i16i32_desc(a: T.handle, b: T.handle, c: T.handle) -> None: | ||
| A = T.match_buffer(a, (2,), "int16", offset_factor=1, scope=mem_scope) | ||
| B = T.match_buffer(b, (32, 2), "int16", offset_factor=1, scope=mem_scope) | ||
| C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope) | ||
| A = T.match_buffer(a, (2,), "int16", offset_factor=1, scope=read_mem_scope) | ||
| B = T.match_buffer(b, (32, 2), "int16", offset_factor=1, scope=read_mem_scope) | ||
| C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=write_mem_scope) | ||
| with T.block("root"): | ||
| T.reads(C[0:32], A[0:2], B[0:32, 0:2]) | ||
| T.writes(C[0:32]) | ||
|
|
@@ -169,9 +170,9 @@ def dot_product_32x2_i16i16i32_desc(a: T.handle, b: T.handle, c: T.handle) -> No | |
|
|
||
| @T.prim_func | ||
| def dot_product_32x2_i16i16i32_vdmpy(a: T.handle, b: T.handle, c: T.handle) -> None: | ||
| A = T.match_buffer(a, (2,), "int16", offset_factor=1, scope=mem_scope) | ||
| B = T.match_buffer(b, (32, 2), "int16", offset_factor=1, scope=mem_scope) | ||
| C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=mem_scope) | ||
| A = T.match_buffer(a, (2,), "int16", offset_factor=1, scope=read_mem_scope) | ||
| B = T.match_buffer(b, (32, 2), "int16", offset_factor=1, scope=read_mem_scope) | ||
| C = T.match_buffer(c, (32,), "int32", offset_factor=1, scope=write_mem_scope) | ||
| with T.block("root"): | ||
| T.reads(C[0:32], A[0:2], B[0:32, 0:2]) | ||
| T.writes(C[0:32]) | ||
|
|
@@ -207,10 +208,22 @@ def dot_product_32x2_i16i16i32_vdmpy(a: T.handle, b: T.handle, c: T.handle) -> N | |
| TensorIntrin.register(VDMPY_i16i16i32_INTRIN, *generate_dot_product_32x2_i16i16i32()) | ||
|
|
||
| VRMPY_u8u8i32_VTCM_INTRIN = "dot_32x4_u8u8i32_vtcm_vrmpy" | ||
| TensorIntrin.register(VRMPY_u8u8i32_VTCM_INTRIN, *generate_dot_product_32x4_u8u8i32("global.vtcm")) | ||
| TensorIntrin.register( | ||
| VRMPY_u8u8i32_VTCM_INTRIN, | ||
| *generate_dot_product_32x4_u8u8i32("global.vtcm", "global.vtcm"), | ||
| ) | ||
|
|
||
| VRMPY_u8u8i32_VTCM_READS_INTRIN = "dot_32x4_u8u8i32_vtcm_reads_vrmpy" | ||
| TensorIntrin.register( | ||
| VRMPY_u8u8i32_VTCM_READS_INTRIN, | ||
| *generate_dot_product_32x4_u8u8i32("global.vtcm", "global"), | ||
| ) | ||
|
|
||
| VRMPY_u8i8i32_VTCM_INTRIN = "dot_32x4_u8i8i32_vtcm_vrmpy" | ||
| TensorIntrin.register(VRMPY_u8i8i32_VTCM_INTRIN, *generate_dot_product_32x4_u8i8i32("global.vtcm")) | ||
| TensorIntrin.register( | ||
| VRMPY_u8i8i32_VTCM_INTRIN, | ||
| *generate_dot_product_32x4_u8i8i32("global.vtcm", "global.vtcm"), | ||
| ) | ||
|
|
||
| DMA_READ_128_u8 = "dma_read_128_u8" | ||
| TensorIntrin.register(DMA_READ_128_u8, *generate_dma_load_intrin(128, "uint8")) | ||
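The registrations above all follow one pattern: a generator parameterized by read/write memory scopes returns a (description, implementation) pair, which is registered under a scope-specific name. A minimal dict-based sketch of that pattern (the registry and names here are illustrative, not the TVM API):

```python
# Illustrative sketch of the (desc, impl) registration pattern used above;
# the registry dict stands in for TensorIntrin.register.

TENSOR_INTRIN_REGISTRY = {}

def register(name, desc, impl):
    TENSOR_INTRIN_REGISTRY[name] = (desc, impl)

def generate_dot_product(read_mem_scope="global", write_mem_scope="global"):
    # Placeholders for the TIR description and hardware implementation.
    desc = f"desc(read={read_mem_scope}, write={write_mem_scope})"
    impl = f"impl(read={read_mem_scope}, write={write_mem_scope})"
    return desc, impl

# Mirrors the diff: an all-VTCM variant and a VTCM-reads-only variant.
register("dot_vtcm", *generate_dot_product("global.vtcm", "global.vtcm"))
register("dot_vtcm_reads", *generate_dot_product("global.vtcm", "global"))

print(TENSOR_INTRIN_REGISTRY["dot_vtcm_reads"][0])
# desc(read=global.vtcm, write=global)
```

Splitting `mem_scope` into `read_mem_scope` and `write_mem_scope` is what makes the mixed-scope `*_VTCM_READS_INTRIN` variant expressible at all.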
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,147 @@ | ||
| /* | ||
| * Licensed to the Apache Software Foundation (ASF) under one | ||
| * or more contributor license agreements. See the NOTICE file | ||
| * distributed with this work for additional information | ||
| * regarding copyright ownership. The ASF licenses this file | ||
| * to you under the Apache License, Version 2.0 (the | ||
| * "License"); you may not use this file except in compliance | ||
| * with the License. You may obtain a copy of the License at | ||
| * | ||
| * http://www.apache.org/licenses/LICENSE-2.0 | ||
| * | ||
| * Unless required by applicable law or agreed to in writing, | ||
| * software distributed under the License is distributed on an | ||
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
| * KIND, either express or implied. See the License for the | ||
| * specific language governing permissions and limitations | ||
| * under the License. | ||
| */ | ||
|
|
||
| #include "../../tir/schedule/analysis.h" | ||
| #include "../../tir/schedule/transform.h" | ||
| #include "../utils.h" | ||
| #include "multi_level_tiling_with_intrin.h" | ||
|
|
||
| namespace tvm { | ||
| namespace meta_schedule { | ||
|
|
||
| using tir::BlockRV; | ||
| using tir::LoopRV; | ||
| using tir::Schedule; | ||
|
|
||
| class MultiLevelTilingHexagonNode : public MultiLevelTilingWithIntrinNode { | ||
| private: | ||
| // Subrule: Add software pipeline | ||
| inline std::vector<State> AddSoftwarePipeline(State state) const; | ||
|
|
||
| // Override ApplySubRules to apply tensorization-specific sub-rules | ||
| std::vector<State> ApplySubRules(std::vector<State> states) final; | ||
|
|
||
| // Inherited from ScheduleRuleNode | ||
| ScheduleRule Clone() const override { | ||
| ObjectPtr<MultiLevelTilingHexagonNode> n = make_object<MultiLevelTilingHexagonNode>(*this); | ||
| return ScheduleRule(n); | ||
| } | ||
|
|
||
| public: | ||
| /*! \brief Whether to use software pipeline */ | ||
| bool use_software_pipeline = false; | ||
| static constexpr const char* _type_key = "meta_schedule.MultiLevelTilingHexagon"; | ||
| TVM_DECLARE_FINAL_OBJECT_INFO(MultiLevelTilingHexagonNode, MultiLevelTilingNode); | ||
| }; | ||
|
|
||
| std::vector<State> MultiLevelTilingHexagonNode::ApplySubRules(std::vector<State> states) { | ||
| states = MultiLevelTilingWithIntrinNode::ApplySubRules(states); | ||
| states = SubRule(std::move(states), [&](State state) { return AddSoftwarePipeline(state); }); | ||
|
Contributor: Seems that …

Contributor (author): I am not sure I follow. If we don't inherit we would have to copy all of the logic needed over from MLT with intrin. But we don't want to add software pipelines to all usages of MLT with intrin as this will try to optimize for hexagon.

Contributor: My comment may be naïve. My read of the code was that the only connection between …
||
| return states; | ||
| } | ||
|
|
||
| std::vector<State> MultiLevelTilingHexagonNode::AddSoftwarePipeline(State state) const { | ||
| if (!use_software_pipeline) { | ||
| return {state}; | ||
| } | ||
| // The current config is not suitable for software pipelining if there are fewer than two | ||
| // reduction indices (r_indices_). | ||
| if (r_indices_.size() < 2) { | ||
| return {state}; | ||
| } | ||
|
|
||
| Schedule& sch = state->sch; | ||
| // Check reduction length after blockize. | ||
| int64_t reduction_length = 1; | ||
| for (int r_index : r_indices_) { | ||
| const Array<LoopRV>& tiles = state->tiles[r_index]; | ||
| for (const LoopRV& tile : tiles) { | ||
| const auto* extent = sch->Get(tile)->extent.as<IntImmNode>(); | ||
| ICHECK(extent != nullptr) << "Dynamic extent is not supported."; | ||
| reduction_length *= extent->value; | ||
| } | ||
| } | ||
| if (reduction_length <= 1) { | ||
|
Contributor: Curious use of …

Contributor (author): It should not be possible, but since extents with value 0 can happen, the length could end up being 0.
||
| return {state}; | ||
| } | ||
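The guard above can be sketched in plain Python (hypothetical helpers, not the C++ code): the reduction length is the product of the reduction-tile extents, and a product of 1 or less, including the degenerate 0-extent case discussed in the thread, leaves nothing for the pipeline to overlap.

```python
# Hypothetical sketch of the reduction-length check: multiply the extents
# of all reduction tiles; <= 1 means software pipelining has no work to
# overlap, so the state is returned unchanged.

def reduction_length(reduction_tile_extents):
    length = 1
    for extent in reduction_tile_extents:
        length *= extent
    return length

def worth_pipelining(reduction_tile_extents):
    return reduction_length(reduction_tile_extents) > 1

print(worth_pipelining([4, 8]))  # True
print(worth_pipelining([1, 1]))  # False: nothing to overlap
print(worth_pipelining([0, 8]))  # False: degenerate 0-extent loop
```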
|
|
||
| // Return if there are fewer than 1 or more than 2 cache reads. | ||
| size_t cache_read_count = state->read_reuse.size(); | ||
| if (cache_read_count > 2 || cache_read_count == 0) { | ||
| return {state}; | ||
| } | ||
|
|
||
| // Add annotations for software pipelining at the loop right above the cache read stages. | ||
| Array<Integer> software_pipeline_stage; | ||
| Array<Integer> software_pipeline_order; | ||
| Array<Integer> software_pipeline_async_stages; | ||
| if (cache_read_count == 2) { | ||
|
Contributor: This looks correct but this notation is difficult to read for many folks. Some comments might help.

Contributor (author): Tried to explain as best I could!
||
| software_pipeline_stage = | ||
| Array<Integer>{0, 0, 1}; // The pipeline merges the first 2 stages into one. | ||
| software_pipeline_order = Array<Integer>{0, 1, 2}; | ||
| software_pipeline_async_stages = Array<Integer>{0}; // The 0th stage is set as async. | ||
| } else { | ||
| software_pipeline_stage = Array<Integer>{0, 1}; | ||
| software_pipeline_order = Array<Integer>{0, 1}; | ||
| software_pipeline_async_stages = Array<Integer>{0}; | ||
| } | ||
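A simplified simulator (an assumption-laden model, not TVM's software pipeline injection pass) of what the stage/order annotations above request: with stage = [0, 0, 1], the two cache reads share stage 0 and the compute is stage 1, so at steady state the reads prefetch iteration t while the compute consumes iteration t - 1, and stage 0 being async lets the DMA loads overlap the compute.

```python
# Simplified model of software pipelining: an op with stage s works on
# logical iteration t - s at pipeline time t, so higher-stage ops lag
# behind and overlap with the prefetching of later iterations.

def simulate(stages, names, n_iters):
    max_stage = max(stages)
    timeline = []
    for t in range(n_iters + max_stage):  # extra slots form the epilogue
        slot = []
        for name, s in zip(names, stages):
            i = t - s  # logical iteration this op works on at time t
            if 0 <= i < n_iters:
                slot.append(f"{name}[{i}]")
        timeline.append(slot)
    return timeline

for t, slot in enumerate(simulate([0, 0, 1], ["read_A", "read_B", "compute"], 3)):
    print(t, slot)
# 0 ['read_A[0]', 'read_B[0]']                  <- prologue: prefetch only
# 1 ['read_A[1]', 'read_B[1]', 'compute[0]']    <- steady state
# 2 ['read_A[2]', 'read_B[2]', 'compute[1]']
# 3 ['compute[2]']                              <- epilogue: drain
```

The single-cache-read case (stage = [0, 1]) behaves the same way with one read per slot instead of two.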
|
|
||
| tir::BlockRV cache_read_block = state->read_reuse.begin()->second; | ||
| Array<LoopRV> cache_read_loops = sch->GetLoops(cache_read_block); | ||
| Array<LoopRV> reduction_loops; | ||
| for (size_t i = 0; i < cache_read_loops.size() - 1; ++i) { | ||
| if (tir::GetLoopIterType(sch->GetSRef(cache_read_loops[i])) != tir::IterVarType::kDataPar) { | ||
| reduction_loops.push_back(cache_read_loops[i]); | ||
| } else if (reduction_loops.size() > 0 && | ||
| sch->Get(cache_read_loops[i])->extent.as<IntImmNode>()->value == 1) { | ||
| reduction_loops.push_back(cache_read_loops[i]); | ||
| } | ||
| } | ||
| auto fused = sch->Fuse(reduction_loops); | ||
|
|
||
| sch->Annotate(fused, tir::attr::software_pipeline_stage, software_pipeline_stage); | ||
| sch->Annotate(fused, tir::attr::software_pipeline_order, software_pipeline_order); | ||
| sch->Annotate(fused, tir::attr::software_pipeline_async_stages, software_pipeline_async_stages); | ||
|
|
||
| // TODO(nverke): Add support for nested async pipelines. | ||
| // TODO(nverke): Add support for async cache writes. | ||
|
Contributor: Is the lack of cache write support here due to the issue in the InjectSWPipeline pass where there is no "wait" on the cache write stage?

Contributor (author): No, it's just from limiting the complexity of this to start.
||
|
|
||
| return {state}; | ||
| } | ||
|
|
||
| ScheduleRule ScheduleRule::MultiLevelTilingHexagon( | ||
| Array<Map<String, String>> intrin_groups, String structure, Optional<Array<String>> tile_binds, | ||
| Optional<Integer> max_innermost_factor, Optional<Array<Integer>> vector_load_lens, | ||
| Optional<Map<String, ObjectRef>> reuse_read, Optional<Map<String, ObjectRef>> reuse_write, | ||
| bool use_software_pipeline) { | ||
| CHECK(!tile_binds.defined()) << "Tile binds cannot be used on hexagon."; | ||
| auto node = MultiLevelTilingInitCommon<MultiLevelTilingHexagonNode>( | ||
| structure, tile_binds, max_innermost_factor, vector_load_lens, reuse_read, reuse_write); | ||
|
|
||
| node->intrin_name = intrin_groups[0]["compute"]; | ||
| node->use_software_pipeline = use_software_pipeline; | ||
| return ScheduleRule(node); | ||
| } | ||
|
|
||
| TVM_REGISTER_NODE_TYPE(MultiLevelTilingHexagonNode); | ||
| TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleMultiLevelTilingHexagon") | ||
| .set_body_typed(ScheduleRule::MultiLevelTilingHexagon); | ||
|
|
||
| } // namespace meta_schedule | ||
| } // namespace tvm | ||
Contributor: You might map SRSRS to the layout NCHWc in the comment.

Contributor (author): This should be applicable outside of schedules that have input of NCHWc, but are you suggesting adding that to help better map this to something?

Contributor: It just took me a minute to make the connection between the Hexagon layout (NCHWc) and the tiling structure (SRSRS), which led me to suggest a comment to clarify. The key thing for my understanding when I originally reviewed the code was to connect R=Reduction to the C=Channel axes.