From f85f0b6ce6be917a4f5ab2807f555368b8723ead Mon Sep 17 00:00:00 2001 From: Anirudh Sundar Date: Tue, 25 Apr 2023 15:03:15 +0530 Subject: [PATCH] [TIR] [Hexagon] Add get_vtcm_allocation_sizes with lowering This patch adds a utility function for getting the VTCM sizes allocated in an IRModule. In order to do that, we've exposed the list of lowering passes to python and we've refactored the PostprocVerifyVTCMLimit to be computed for the whole module using the same list of lowering passes --- include/tvm/tir/analysis.h | 16 ++++++ python/tvm/tir/analysis/analysis.py | 13 +++++ python/tvm/topi/hexagon/utils.py | 52 +++++++++++++++++-- .../postproc/verify_vtcm_limit.cc | 44 +++------------- .../analysis/calculate_allocated_memory.cc | 33 ++++++++++++ 5 files changed, 119 insertions(+), 39 deletions(-) diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index 3b5959e7816d..f4684231f008 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -184,6 +184,22 @@ TVM_DLL bool VerifyMemory(const PrimFunc& func); */ TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map constraints); +/** + * @brief Utility function to get the list of lowering passes to be applied to calculate the + * compacted VTCM allocation size + * + * @return returns list of passes + */ +TVM_DLL Array GetVTCMCompactionPasses(); + +/*! + * \brief Verifies that the VTCM usage of all prim_funcs in the given IRModule is within the limit. + * \param mod The module to be checked + * \param limit The limit to check. + * \return true if the VTCM usage is within the provided limit. + */ +TVM_DLL bool VerifyVTCMLimit(const IRModule& mod, Integer limit); + /*! * \brief Verifies that the VTCM usage of the given prim_func is within the provided limit. * \param func The function to be checked. 
diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index 387ea0498015..1a5f8b978168 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -18,6 +18,7 @@ # pylint: disable=invalid-name from typing import Dict, List, Union +import tvm from tvm import Object from tvm.ir import IRModule from tvm.tir.expr import Var @@ -384,3 +385,15 @@ def find_anchor_block(mod: IRModule) -> Block: The anchor block if found, None otherwise. """ return _ffi_api.find_anchor_block(mod) # type: ignore # pylint: disable=no-member + + +def get_vtcm_compaction_passes() -> List[tvm.transform.Pass]: + """Utility function to get the list of lowering passes to be applied to calculate the + compacted VTCM allocation size + + Returns + ------- + result : List[tvm.transform.Pass] + returns list of passes + """ + return _ffi_api.get_vtcm_compaction_passes() # type: ignore # pylint: disable=no-member diff --git a/python/tvm/topi/hexagon/utils.py b/python/tvm/topi/hexagon/utils.py index 3148360cc2d1..f017aaebbdeb 100644 --- a/python/tvm/topi/hexagon/utils.py +++ b/python/tvm/topi/hexagon/utils.py @@ -21,9 +21,11 @@ """Common hexagon specific utilities""" import math import struct -from typing import Tuple -from tvm import te -from tvm.tir import IndexMap +from typing import Dict, Tuple, Union + +import tvm +from tvm import IRModule, te +from tvm.tir import IndexMap, PrimFunc def n11c_1024c_2d(n, h, w, c): @@ -354,3 +356,47 @@ def within_range(val, dtype): def saturate(x: te.Tensor, dtype: str): """Saturate value for the specified data type""" return te.max(te.min_value(dtype), te.min(x, te.max_value(dtype))) + + +def get_vtcm_allocation_sizes( + func_or_mod: Union[PrimFunc, IRModule], compacted=True +) -> Dict[str, int]: + """Calculate and return the vtcm allocation sizes for all the functions in + the IRModule or just the vtcm size if a single PrimFunc is passed + + Parameters + ---------- + func_or_mod : Union[PrimFunc, 
IRModule] + PrimFunc or IRModule for which VTCM allocation size is to be calculated + compacted : + Whether to calculate the sizes after applying VTCM lowering passes for + buffer compaction. This helps return the VTCM size that would get + allocated after lowering + + Returns + ------- + result : Dict[str, int] + A dict with function names as keys and vtcm allocated + inside that function as values + + """ + if not isinstance(func_or_mod, (PrimFunc, IRModule)): + raise TypeError( + f"Expected argument to be PrimFunc or IRModule, but received {type(func_or_mod)}" + ) + if isinstance(func_or_mod, tvm.tir.PrimFunc): + mod = tvm.IRModule.from_expr(func_or_mod) + else: + mod = func_or_mod + if compacted: + passes = tvm.tir.analysis.get_vtcm_compaction_passes() + mod = tvm.transform.Sequential(list(passes))(mod) + + result = {} + all_sizes = tvm.tir.analysis.calculate_allocated_bytes(mod) + for func_name, sizes in all_sizes.items(): + if "global.vtcm" in sizes: + result[func_name] = sizes["global.vtcm"] + else: + result[func_name] = 0 + return result diff --git a/src/meta_schedule/postproc/verify_vtcm_limit.cc b/src/meta_schedule/postproc/verify_vtcm_limit.cc index 46bc7486e1df..4de975089653 100644 --- a/src/meta_schedule/postproc/verify_vtcm_limit.cc +++ b/src/meta_schedule/postproc/verify_vtcm_limit.cc @@ -36,48 +36,20 @@ class VerifyVTCMLimitNode : public PostprocNode { } bool Verify(const IRModule& mod) const { - for (const auto& kv : mod->functions) { - if (auto prim_func = kv.second.as()) { - if (!tir::VerifyVTCMLimit(prim_func.value(), vtcm_capacity)) { - return false; - } - } + if (!tir::VerifyVTCMLimit(mod, vtcm_capacity)) { + return false; } return true; } bool Apply(const tir::Schedule& sch) final { IRModule mod = sch->mod(); - for (const auto& kv : mod->functions) { - const GlobalVar& g_var = kv.first; - const BaseFunc& base_func = kv.second; - if (const auto* prim_func = base_func.as()) { - IRModule lowered{nullptr}; - try { - auto pass_list = Array(); - 
pass_list.push_back(tir::transform::LowerInitBlock()); - pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); - pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); - pass_list.push_back(tir::transform::CompactBufferAllocation()); - pass_list.push_back(tir::transform::LowerMatchBuffer()); - pass_list.push_back(tir::transform::InjectSoftwarePipeline()); - pass_list.push_back(tir::transform::LowerOpaqueBlock()); - pass_list.push_back(tir::transform::FlattenBuffer()); - pass_list.push_back(tir::transform::Simplify()); - pass_list.push_back(tir::transform::VectorizeLoop(true)); - pass_list.push_back(tir::transform::StorageRewrite()); - transform::PassContext pass_ctx = transform::PassContext::Current(); - tir::PrimFunc f = WithAttr(GetRef(prim_func), "global_symbol", - runtime::String(g_var->name_hint)); - IRModule mod = IRModule(Map({{GlobalVar(g_var->name_hint), f}})); - lowered = tvm::transform::Sequential(pass_list)(std::move(mod)); - } catch (const dmlc::Error& e) { - return false; - } - if (!Verify(lowered)) { - return false; - } - } + IRModule lowered{nullptr}; + auto pass_list = tir::GetVTCMCompactionPasses(); + transform::PassContext pass_ctx = transform::PassContext::Current(); + lowered = tvm::transform::Sequential(pass_list)(std::move(mod)); + if (!Verify(lowered)) { + return false; } return true; } diff --git a/src/tir/analysis/calculate_allocated_memory.cc b/src/tir/analysis/calculate_allocated_memory.cc index 8680f57e4cfd..3a41c5ac5a25 100644 --- a/src/tir/analysis/calculate_allocated_memory.cc +++ b/src/tir/analysis/calculate_allocated_memory.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -109,6 +110,18 @@ TVM_REGISTER_GLOBAL("tir.analysis.calculate_allocated_bytes") } }); +bool VerifyVTCMLimit(const IRModule& mod, Integer limit) { + auto all_sizes = CalculateAllocatedBytes(mod); + for (const auto& kv : all_sizes) { + auto sizes = kv.second; + const auto vtcm_allocated = 
sizes.Get("global.vtcm").value_or(0); + if (limit.IntValue() > 0 && vtcm_allocated.IntValue() > limit.IntValue()) { + return false; + } + } + return true; +} + bool VerifyVTCMLimit(const PrimFunc& func, Integer limit) { auto sizes = CalculateAllocatedBytes(func)["main"]; const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0); @@ -127,6 +140,26 @@ int64_t GetVTCMCapacity(Target target, const transform::PassContext& pass_ctx) { return pass_ctx->GetConfig("tir.vtcm_capacity", Integer(0)).value()->value; } +Array GetVTCMCompactionPasses() { + auto pass_list = Array(); + pass_list.push_back(tir::transform::LowerInitBlock()); + pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); + pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); + pass_list.push_back(tir::transform::CompactBufferAllocation()); + pass_list.push_back(tir::transform::LowerMatchBuffer()); + pass_list.push_back(tir::transform::InjectSoftwarePipeline()); + pass_list.push_back(tir::transform::LowerOpaqueBlock()); + pass_list.push_back(tir::transform::FlattenBuffer()); + pass_list.push_back(tir::transform::Simplify()); + pass_list.push_back(tir::transform::VectorizeLoop(true)); + pass_list.push_back(tir::transform::StorageRewrite()); + return pass_list; +} + +TVM_REGISTER_GLOBAL("tir.analysis.get_vtcm_compaction_passes").set_body_typed([]() { + return GetVTCMCompactionPasses(); +}); + namespace transform { Pass VerifyVTCMLimit(Optional default_target) {