diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index 3b5959e7816d..f4684231f008 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -184,6 +184,22 @@ TVM_DLL bool VerifyMemory(const PrimFunc& func); */ TVM_DLL bool VerifyGPUCode(const PrimFunc& func, Map constraints); +/*! + * \brief Utility function to get the list of lowering passes to be applied to calculate the + * compacted VTCM allocation size + * + * \return The list of lowering passes + */ +TVM_DLL Array GetVTCMCompactionPasses(); + +/*! + * \brief Verifies that the VTCM usage of every prim_func in the IRModule is within the limit. + * \param mod The module to be checked + * \param limit The limit to check. A non-positive limit disables the check. + * \return true if the VTCM usage is within the provided limit. + */ +TVM_DLL bool VerifyVTCMLimit(const IRModule& mod, Integer limit); + /*! * \brief Verifies that the VTCM usage of the given prim_func is within the provided limit. * \param func The function to be checked. diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index 387ea0498015..1a5f8b978168 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -18,6 +18,7 @@ # pylint: disable=invalid-name from typing import Dict, List, Union +import tvm from tvm import Object from tvm.ir import IRModule from tvm.tir.expr import Var @@ -384,3 +385,15 @@ def find_anchor_block(mod: IRModule) -> Block: The anchor block if found, None otherwise. 
""" return _ffi_api.find_anchor_block(mod) # type: ignore # pylint: disable=no-member + + +def get_vtcm_compaction_passes() -> List[tvm.transform.Pass]: + """Utility function to get the list of lowering passes to be applied to calculate the + compacted VTCM allocation size + + Returns + ------- + result : List[tvm.transform.Pass] + The list of lowering passes + """ + return _ffi_api.get_vtcm_compaction_passes() # type: ignore # pylint: disable=no-member diff --git a/python/tvm/topi/hexagon/utils.py b/python/tvm/topi/hexagon/utils.py index 3148360cc2d1..f017aaebbdeb 100644 --- a/python/tvm/topi/hexagon/utils.py +++ b/python/tvm/topi/hexagon/utils.py @@ -21,9 +21,11 @@ """Common hexagon specific utilities""" import math import struct -from typing import Tuple -from tvm import te -from tvm.tir import IndexMap +from typing import Dict, Tuple, Union + +import tvm +from tvm import IRModule, te +from tvm.tir import IndexMap, PrimFunc def n11c_1024c_2d(n, h, w, c): @@ -354,3 +356,47 @@ def within_range(val, dtype): def saturate(x: te.Tensor, dtype: str): """Saturate value for the specified data type""" return te.max(te.min_value(dtype), te.min(x, te.max_value(dtype))) + + +def get_vtcm_allocation_sizes( + func_or_mod: Union[PrimFunc, IRModule], compacted=True +) -> Dict[str, int]: + """Calculate and return the vtcm allocation sizes for all the functions in + the IRModule; a single PrimFunc is first wrapped into an IRModule + + Parameters + ---------- + func_or_mod : Union[PrimFunc, IRModule] + PrimFunc or IRModule for which VTCM allocation size is to be calculated + compacted : bool + Whether to calculate the sizes after applying VTCM lowering passes for + buffer compaction. 
This helps return the VTCM size that would get + allocated after lowering + + Returns + ------- + result : Dict[str, int] + A dict with function names as keys and the VTCM size allocated + inside that function (0 when it has no ``global.vtcm`` entry) as values + + """ + if not isinstance(func_or_mod, (PrimFunc, IRModule)): + raise TypeError( + f"Expected argument to be PrimFunc or IRModule, but received {type(func_or_mod)}" + ) + if isinstance(func_or_mod, tvm.tir.PrimFunc): + mod = tvm.IRModule.from_expr(func_or_mod) + else: + mod = func_or_mod + if compacted: + passes = tvm.tir.analysis.get_vtcm_compaction_passes() + mod = tvm.transform.Sequential(list(passes))(mod) + + result = {} + all_sizes = tvm.tir.analysis.calculate_allocated_bytes(mod) + for func_name, sizes in all_sizes.items(): + if "global.vtcm" in sizes: + result[func_name] = sizes["global.vtcm"] + else: + result[func_name] = 0 + return result diff --git a/src/meta_schedule/postproc/verify_vtcm_limit.cc b/src/meta_schedule/postproc/verify_vtcm_limit.cc index 46bc7486e1df..4de975089653 100644 --- a/src/meta_schedule/postproc/verify_vtcm_limit.cc +++ b/src/meta_schedule/postproc/verify_vtcm_limit.cc @@ -36,48 +36,20 @@ class VerifyVTCMLimitNode : public PostprocNode { } bool Verify(const IRModule& mod) const { - for (const auto& kv : mod->functions) { - if (auto prim_func = kv.second.as()) { - if (!tir::VerifyVTCMLimit(prim_func.value(), vtcm_capacity)) { - return false; - } - } + if (!tir::VerifyVTCMLimit(mod, vtcm_capacity)) { + return false; } return true; } bool Apply(const tir::Schedule& sch) final { IRModule mod = sch->mod(); - for (const auto& kv : mod->functions) { - const GlobalVar& g_var = kv.first; - const BaseFunc& base_func = kv.second; - if (const auto* prim_func = base_func.as()) { - IRModule lowered{nullptr}; - try { - auto pass_list = Array(); - pass_list.push_back(tir::transform::LowerInitBlock()); - pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); - 
pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); - pass_list.push_back(tir::transform::CompactBufferAllocation()); - pass_list.push_back(tir::transform::LowerMatchBuffer()); - pass_list.push_back(tir::transform::InjectSoftwarePipeline()); - pass_list.push_back(tir::transform::LowerOpaqueBlock()); - pass_list.push_back(tir::transform::FlattenBuffer()); - pass_list.push_back(tir::transform::Simplify()); - pass_list.push_back(tir::transform::VectorizeLoop(true)); - pass_list.push_back(tir::transform::StorageRewrite()); - transform::PassContext pass_ctx = transform::PassContext::Current(); - tir::PrimFunc f = WithAttr(GetRef(prim_func), "global_symbol", - runtime::String(g_var->name_hint)); - IRModule mod = IRModule(Map({{GlobalVar(g_var->name_hint), f}})); - lowered = tvm::transform::Sequential(pass_list)(std::move(mod)); - } catch (const dmlc::Error& e) { - return false; - } - if (!Verify(lowered)) { - return false; - } - } + IRModule lowered{nullptr}; + auto pass_list = tir::GetVTCMCompactionPasses(); + transform::PassContext pass_ctx = transform::PassContext::Current(); + lowered = tvm::transform::Sequential(pass_list)(std::move(mod)); + if (!Verify(lowered)) { + return false; } return true; } diff --git a/src/tir/analysis/calculate_allocated_memory.cc b/src/tir/analysis/calculate_allocated_memory.cc index 8680f57e4cfd..3a41c5ac5a25 100644 --- a/src/tir/analysis/calculate_allocated_memory.cc +++ b/src/tir/analysis/calculate_allocated_memory.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include @@ -109,6 +110,18 @@ TVM_REGISTER_GLOBAL("tir.analysis.calculate_allocated_bytes") } }); +bool VerifyVTCMLimit(const IRModule& mod, Integer limit) { + auto all_sizes = CalculateAllocatedBytes(mod); + for (const auto& kv : all_sizes) { + auto sizes = kv.second; + const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0); + if (limit.IntValue() > 0 && vtcm_allocated.IntValue() > limit.IntValue()) { + return false; + } + } + return true; 
+} + bool VerifyVTCMLimit(const PrimFunc& func, Integer limit) { auto sizes = CalculateAllocatedBytes(func)["main"]; const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0); @@ -127,6 +140,26 @@ int64_t GetVTCMCapacity(Target target, const transform::PassContext& pass_ctx) { return pass_ctx->GetConfig("tir.vtcm_capacity", Integer(0)).value()->value; } +Array GetVTCMCompactionPasses() { + auto pass_list = Array(); + pass_list.push_back(tir::transform::LowerInitBlock()); + pass_list.push_back(tir::transform::PlanAndUpdateBufferAllocationLocation()); + pass_list.push_back(tir::transform::ConvertBlocksToOpaque()); + pass_list.push_back(tir::transform::CompactBufferAllocation()); + pass_list.push_back(tir::transform::LowerMatchBuffer()); + pass_list.push_back(tir::transform::InjectSoftwarePipeline()); + pass_list.push_back(tir::transform::LowerOpaqueBlock()); + pass_list.push_back(tir::transform::FlattenBuffer()); + pass_list.push_back(tir::transform::Simplify()); + pass_list.push_back(tir::transform::VectorizeLoop(true)); + pass_list.push_back(tir::transform::StorageRewrite()); + return pass_list; +} + +TVM_REGISTER_GLOBAL("tir.analysis.get_vtcm_compaction_passes").set_body_typed([]() { + return GetVTCMCompactionPasses(); +}); + namespace transform { Pass VerifyVTCMLimit(Optional default_target) {