From a2040db182180f2aedde5ed481591cc4e658be40 Mon Sep 17 00:00:00 2001 From: Alexey Voronov Date: Thu, 1 Dec 2022 11:16:03 +0000 Subject: [PATCH 1/2] [TIR][Hexagon] Add vtcm memory capacity verification for Hexagon target --- include/tvm/tir/analysis.h | 16 +++ python/tvm/autotvm/measure/measure_methods.py | 33 ++++- python/tvm/target/target.py | 8 ++ python/tvm/tir/analysis/analysis.py | 16 +++ python/tvm/tir/transform/transform.py | 11 ++ src/auto_scheduler/feature.cc | 7 ++ src/auto_scheduler/search_policy/utils.h | 5 + src/driver/driver_api.cc | 17 ++- src/target/target_kind.cc | 1 + .../analysis/calculate_allocated_memory.cc | 117 ++++++++++++++++++ .../contrib/test_hexagon/infrastructure.py | 4 +- .../python/contrib/test_hexagon/test_vtcm.py | 53 +++++--- ...tir_analysis_calculate_allocated_memory.py | 101 +++++++++++++++ 13 files changed, 364 insertions(+), 25 deletions(-) create mode 100644 src/tir/analysis/calculate_allocated_memory.cc create mode 100644 tests/python/unittest/test_tir_analysis_calculate_allocated_memory.py diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index e9796eca6505..cb31a7e5ee96 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -217,6 +217,12 @@ TVM_DLL size_t CalculateConstantBytes(const PrimFunc& func, const Integer& const TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func, const Integer& workspace_byte_alignment); +/*! + * \brief Calculate the allocated memory per scope in bytes needed inside the TIR PrimFunc + * \param func The TIR PrimFunc for which the allocated memory size is to be calculated + */ +TVM_DLL tvm::Map<String, Integer> CalculateAllocatedBytes(const PrimFunc& func); + /*! * \brief Detect the lowest common ancestor(LCA) of buffer access, including both high-level * access(BufferLoad, BufferStore) and low-level access(Load, Store and opaque access). @@ -294,6 +300,16 @@ TVM_DLL Pass VerifyMemory(); */ TVM_DLL Pass VerifyGPUCode(Map<String, PrimExpr> constraints); +/*! + * \brief Pass to check if the size of the allocated vtcm memory satisfies the limit + * + * \param limit The limit to check. + * + * \returns The pass. + * \sa tvm::tir::CalculateAllocatedBytes + */ +TVM_DLL Pass VerifyVTCMLimit(const Integer& limit); + /*! * \brief Statically check TIR code for out of bounds array access.
* diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py index 8fc0da89c4c6..f1c14c3cd914 100644 --- a/python/tvm/autotvm/measure/measure_methods.py +++ b/python/tvm/autotvm/measure/measure_methods.py @@ -330,7 +330,7 @@ def set_task(self, task): ) def get_build_kwargs(self): - kwargs = {} + kwargs = {"checks": {}} if ( "cuda" in self.task.target.keys or "opencl" in self.task.target.keys @@ -340,13 +340,15 @@ def get_build_kwargs(self): remote = request_remote(self.key, self.host, self.port) dev = remote.device(str(self.task.target), 0) max_dims = dev.max_thread_dimensions - kwargs["check_gpu"] = { + kwargs["checks"]["gpu"] = { "max_shared_memory_per_block": dev.max_shared_memory_per_block, "max_threads_per_block": dev.max_threads_per_block, "max_thread_x": max_dims[0], "max_thread_y": max_dims[1], "max_thread_z": max_dims[2], } + if "hexagon" in self.task.target.keys: + kwargs["checks"]["hexagon"] = {"vtcm_capacity": self.task.target.vtcm_capacity} return kwargs @@ -493,11 +495,11 @@ def set_task(self, task): return server, tracker -def _build_func_common(measure_input, runtime=None, check_gpu=None, build_option=None): +def _build_func_common(measure_input, runtime=None, checks=None, build_option=None): """Common part for building a configuration""" target, task, config = measure_input target, task.target_host = Target.canon_target_and_host(target, task.target_host) - + checks = checks or {} with target: s, args = task.instantiate(config) @@ -526,8 +528,10 @@ def _build_func_common(measure_input, runtime=None, check_gpu=None, build_option current_add_lower_pass = list(current_config["tir.add_lower_pass"]) else: current_add_lower_pass = [] - if check_gpu: - current_add_lower_pass.append((2, gpu_verify_pass(**check_gpu))) + if checks.get("gpu"): + current_add_lower_pass.append((2, gpu_verify_pass(**checks.get("gpu")))) + if checks.get("hexagon"): + current_add_lower_pass.append((2, vtcm_verify_pass(**checks.get("hexagon")))) current_config["tir.add_lower_pass"] = current_add_lower_pass with tvm.ir.transform.PassContext( @@ -872,3 +876,20 @@ def verify_pass(f, *_): return f return tvm.tir.transform.prim_func_pass(verify_pass, opt_level=0) + + +def vtcm_verify_pass(**kwargs): + """Verify the validity of a hexagon kernel. + This pass will check vtcm memory usage. + """ + + def verify_pass(f, *_): + sizes = tvm.tir.analysis.calculate_allocated_bytes(f) + vtcm_capacity = kwargs.get("vtcm_capacity", 0) + vtcm_allocated = sizes.get("global.vtcm", 0) + if 0 < vtcm_capacity < vtcm_allocated: + raise InstantiationError("Skipped because of invalid vtcm memory usage limit") + + return f + + return tvm.tir.transform.prim_func_pass(verify_pass, opt_level=0) diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py index 7081f992afd9..06e1776965c2 100644 --- a/python/tvm/target/target.py +++ b/python/tvm/target/target.py @@ -182,6 +182,10 @@ def thread_warp_size(self): def max_function_args(self): return int(self.attrs.get("max_function_args", -1)) + @property + def vtcm_capacity(self): + return int(self.attrs.get("vtcm-capacity", 0)) + @property def device_name(self): return str(self.attrs.get("device", "")) @@ -642,6 +646,8 @@ def hexagon(cpu_ver="v66", **kwargs): Whether to use IEEE HVX instructions num_cores : int (default: 4) The number of HVX threads. This attribute is required by meta scheduler. + vtcm_capacity: int (default: 0) + Hexagon VTCM capacity limitation. If the value is 0, the capacity is treated as unbounded. 
Note: Floating point support in HVX requires LLVM 14+. """ @@ -675,6 +681,7 @@ def get_arch_version(cpu_ver): "llvm_options": None, "use_qfloat": arch_version >= 68, "use_ieee_fp": False, + "vtcm_capacity": 0, } config.update(kwargs) @@ -748,6 +755,7 @@ def create_llvm_options(cpu_ver, config): # pylint: disable=unused-argument num_cores = config["num_cores"] if "num_cores" in kwargs else 4 args_list.append("--num-cores=%d" % num_cores) + args_list.append("--vtcm-capacity=%d" % config["vtcm_capacity"]) return Target(" ".join(["hexagon"] + args_list)) diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index efb869efd6dc..45b1f745c3de 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -201,6 +201,22 @@ def calculate_constant_bytes(func: PrimFunc, constant_byte_alignment: int) -> in return _ffi_api.calculate_constant_bytes(func, constant_byte_alignment) # type: ignore +def calculate_allocated_bytes(func: PrimFunc) -> Dict[str, int]: + """Calculate allocated memory per memory scope required by TIR PrimFuncs. + + Parameters + ---------- + func: tvm.tir.PrimFunc + The function to be analyzed. + + Returns + ------- + result : Dict[str, int] + Allocated memory size per scope in bytes. + """ + return _ffi_api.calculate_allocated_bytes(func) # type: ignore + + def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]: """Detect the lowest common ancestor(LCA) of buffer access, including both high-level access(BufferLoad, BufferStore) and low-level access(Load, Store and opaque access). diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 82533a2f9f5a..81b90d5f4051 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -611,6 +611,17 @@ def VerifyMemory(): return _ffi_api.VerifyMemory() # type: ignore +def VerifyVTCMLimit(limit: int): + """Verify if the size of the allocated vtcm memory satisfies the limit. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.VerifyVTCMLimit(limit) # type: ignore + + # pylint: disable=no-else-return,inconsistent-return-statements def HoistIfThenElse(variant: Optional[str] = None): """Hoist loop-invariant IfThenElse nodes to outside the eligible loops. diff --git a/src/auto_scheduler/feature.cc b/src/auto_scheduler/feature.cc index 2f993c0c8b82..4ce7ad13bc60 100644 --- a/src/auto_scheduler/feature.cc +++ b/src/auto_scheduler/feature.cc @@ -1401,6 +1401,13 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i const auto& optimize = tir::transform::Sequential(pass_list); optimize(mod); } + if (IsHexagonTask(task)) { + Target target = task->target; + const auto vtcm_capacity = target->GetAttr<Integer>("vtcm-capacity").value().IntValue(); + const auto& optimize = + tir::transform::Sequential({tir::transform::VerifyVTCMLimit(vtcm_capacity)}); + optimize(mod); + } const auto& optimize = tir::transform::Sequential(Array<tvm::transform::Pass>{tir::transform::Simplify()}); mod = optimize(std::move(mod)); diff --git a/src/auto_scheduler/search_policy/utils.h b/src/auto_scheduler/search_policy/utils.h index 44b60de1d7ad..ca8979c0e829 100644 --- a/src/auto_scheduler/search_policy/utils.h +++ b/src/auto_scheduler/search_policy/utils.h @@ -58,6 +58,11 @@ inline bool IsGPUTask(const SearchTask& task) { device_type == kDLMetal || device_type == kDLROCM || device_type == kOpenGL; } +/*! \brief Return whether the search task is targeting a Hexagon.
*/ +inline bool IsHexagonTask(const SearchTask& task) { + return (task)->target->GetTargetDeviceType() == kDLHexagon; +} + /*! \brief Return whether the search task is targeting a CUDA GPU. */ inline bool IsCUDATask(const SearchTask& task) { return (task)->target->GetTargetDeviceType() == kDLCUDA; diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index 90676e0b840b..10d9e8023a61 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -54,6 +54,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_async_commit_queue_scope", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool); TVM_REGISTER_PASS_CONFIG_OPTION("tir.dma_bypass_cache", Bool); +TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer); using tvm::Array; using tvm::transform::Pass; @@ -225,8 +226,6 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) { if (!disable_storage_rewrite) { pass_list.push_back(tir::transform::StorageRewrite()); } - // LowerVtcmAlloc must occur after any transformations that modify memory allocation locations - pass_list.push_back(tir::transform::LowerVtcmAlloc()); bool use_async_copy = pass_ctx->GetConfig<Bool>("tir.use_async_copy", Bool(false)).value(); if (use_async_copy) { @@ -532,11 +531,25 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg, return TIRToRuntime(inputs, target_host); } +int64_t GetVTCMCapacity(Target target, const transform::PassContext& pass_ctx) { + if (!target.defined()) target = Target::Current(/*allow_not_defined=*/true); + if (target.defined() && target->kind->name == "hexagon") { + auto value = Downcast<Integer>(target->attrs.at("vtcm-capacity"))->value; + if (value > 0) return value; + } + return pass_ctx->GetConfig<Integer>("tir.vtcm_capacity", Integer(0)).value()->value; +} + transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) { transform::PassContext pass_ctx = transform::PassContext::Current(); Array<Pass> mixed_pass_list; + // VerifyVTCMLimit must occur before LowerVtcmAlloc + mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(GetVTCMCapacity(target, pass_ctx))); + // LowerVtcmAlloc must occur after any transformations that modify memory allocation locations + mixed_pass_list.push_back(tir::transform::LowerVtcmAlloc()); + mixed_pass_list.push_back(tir::transform::BindTarget(target)); mixed_pass_list.push_back(tir::transform::VerifyMemory()); diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index ef350004ad52..a87bb92c483b 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -421,6 +421,7 @@ TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon) .add_attr_option<String>("mtriple") .add_attr_option<Array<String>>("llvm-options") .add_attr_option<Integer>("num-cores") + .add_attr_option<Integer>("vtcm-capacity") .set_default_keys({"hexagon"}); TVM_REGISTER_TARGET_KIND("stackvm", kDLCPU); diff --git a/src/tir/analysis/calculate_allocated_memory.cc b/src/tir/analysis/calculate_allocated_memory.cc new file mode 100644 index 000000000000..01457508ab95 --- /dev/null +++ b/src/tir/analysis/calculate_allocated_memory.cc @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tir/analysis/calculate_allocated_memory.cc + * \brief Calculate allocated memory per memory scope required by PrimFuncs. + */ +#include <tvm/ir/module.h> +#include <tvm/runtime/registry.h> +#include <tvm/tir/analysis.h> +#include <tvm/tir/function.h> +#include <tvm/tir/stmt.h> +#include <tvm/tir/stmt_functor.h> +#include <tvm/tir/transform.h> + +#include <algorithm> +#include <string> +#include <unordered_map> + +namespace tvm { +namespace tir { + +template <typename T> +class AllocationCalculator : public StmtExprVisitor { + public: + AllocationCalculator() = default; + tvm::Map<String, Integer> operator()(const PrimFunc& func); + + private: + void VisitStmt_(const T* op) override; + std::unordered_map<std::string, int64_t> _max_size; + std::unordered_map<std::string, int64_t> _current_size; +}; + +template <typename T> +tvm::Map<String, Integer> AllocationCalculator<T>::operator()(const PrimFunc& func) { + this->VisitStmt(func->body); + tvm::Map<String, Integer> res; + for (auto [k, v] : _max_size) { + res.Set(String(k), Integer(v)); + } + return res; +} + +std::string GetStorageScope(const Var& var) { + auto* ptr = var->type_annotation.as<PointerTypeNode>(); + ICHECK(ptr) << "Buffer Var's type annotation must be of PointerType"; + return ptr->storage_scope; +} + +template <typename T> +void AllocationCalculator<T>::VisitStmt_(const T* op) { + std::string storage_scope = GetStorageScope(op->buffer_var); + auto search = _current_size.find(storage_scope); + if (search == _current_size.end()) { + _current_size[storage_scope] = 0; + _max_size[storage_scope] = 0; + } + auto size = op->ConstantAllocationSize() * op->dtype.bytes() * op->dtype.lanes(); + _current_size[storage_scope] += size; + _max_size[storage_scope] = std::max(_current_size[storage_scope], _max_size[storage_scope]); + StmtExprVisitor::VisitStmt(op->body); + _current_size[storage_scope] -= size; +} + +tvm::Map<String, Integer> CalculateAllocatedBytes(const PrimFunc& func) { + return AllocationCalculator<AllocateNode>()(func); +} + +TVM_REGISTER_GLOBAL("tir.analysis.calculate_allocated_bytes").set_body_typed([](PrimFunc func) { + return CalculateAllocatedBytes(func); +}); + +namespace transform { + +Pass VerifyVTCMLimit(const Integer& limit) { + auto pass_func = [=](IRModule mod, PassContext ctx) { + for (auto kv : mod->functions) { + if (auto* n = kv.second.as<PrimFuncNode>()) { + auto func = GetRef<PrimFunc>(n); + auto sizes = CalculateAllocatedBytes(func); + const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0); + if (limit.IntValue() > 0 && vtcm_allocated.IntValue() > limit.IntValue()) { + LOG(FATAL) << "RuntimeError: The global.vtcm memory allocation limit has been " + "exceeded (allocated: " + << vtcm_allocated << ", limit: " << limit << ").\n" + << "In function\n" + << func; + } + } + } + return mod; + }; + return tvm::transform::CreateModulePass(pass_func, 0, "tir.calculate_allocated_bytes", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.VerifyVTCMLimit").set_body_typed(VerifyVTCMLimit); + +} // namespace transform +} // namespace tir +} // namespace tvm diff --git a/tests/python/contrib/test_hexagon/infrastructure.py b/tests/python/contrib/test_hexagon/infrastructure.py index c03701f83ccc..9431507e9cc9 100644 --- a/tests/python/contrib/test_hexagon/infrastructure.py +++ b/tests/python/contrib/test_hexagon/infrastructure.py @@ -324,7 +324,7 @@ def quantize_np(arr_np: numpy.ndarray, dtype: str): return quant_np, scale, zero_point -def get_hexagon_target(cpu_ver:
str) -> tvm.target.Target: +def get_hexagon_target(cpu_ver: str, **kwargs) -> tvm.target.Target: """Creates a Hexagon target""" - target = tvm.target.hexagon(cpu_ver) + target = tvm.target.hexagon(cpu_ver, **kwargs) return tvm.target.Target(target, host=target) diff --git a/tests/python/contrib/test_hexagon/test_vtcm.py b/tests/python/contrib/test_hexagon/test_vtcm.py index 11188436a318..668346b96c37 100644 --- a/tests/python/contrib/test_hexagon/test_vtcm.py +++ b/tests/python/contrib/test_hexagon/test_vtcm.py @@ -16,9 +16,11 @@ # under the License. """VTCM Tests""" +import pytest import tvm.testing from tvm import tir from tvm.script import tir as T +from .infrastructure import get_hexagon_target @T.prim_func @@ -31,8 +33,7 @@ def scale_by_two(buffer_a: T.Buffer[(8192,), "int8"], buffer_c: T.Buffer[(8192,) buffer_c[i] = buffer_a[i] * T.int8(2) -def test_vtcm_lowering(): - """Test lowering with vtcm mem scope""" +def get_scale_by_two_schedule(): mod = tvm.IRModule.from_expr(scale_by_two.with_attr("global_symbol", "main")) sch = tir.Schedule(mod, debug_mask="all") block_c = sch.get_block("C") @@ -40,23 +41,45 @@ def test_vtcm_lowering(): outer, _, _, _ = sch.split(flat, factors=[8, 4, 2, 128]) cache_block = sch.cache_read(block_c, 0, storage_scope="global.vtcm") sch.compute_at(cache_block, outer) - lowered = tvm.lower(sch.mod["main"]) + return sch - def ir_module_has_allocate_nodes(irmod): - nallocs = 0 - def _visit(stmt): - nonlocal nallocs - if isinstance(stmt, tvm.tir.Allocate): - nallocs += 1 +def test_vtcm_building(): + """Test building with vtcm mem scope""" + sch = get_scale_by_two_schedule() + target = get_hexagon_target("v68") + built = tvm.build(sch.mod, target=target) + assert "global.vtcm" in built.get_source("asm") - tvm.tir.stmt_functor.post_order_visit(irmod["main"].body, _visit) - return nallocs - assert not ir_module_has_allocate_nodes(lowered), ( - "AllocateNode found in lowered IRModule, " - "VTCM allocations should have been lowered to tir.nd_mem_alloc_with_scope" - ) +@pytest.mark.parametrize("vtcm_capacity,limited", [(8192, False), (1024, False), (128, True)]) +def test_vtcm_limit(vtcm_capacity, limited): + """Test building with vtcm mem scope limit""" + sch = get_scale_by_two_schedule() + + def _raises_exception(f): + try: + f() + except tvm._ffi.base.TVMError: + return True + return False + + target = get_hexagon_target("v68", vtcm_capacity=vtcm_capacity) + + assert ( + _raises_exception(lambda: tvm.build(sch.mod, target=target)) == limited + ), "Case 1 - arg. VTCM memory allocation limiter does not work correctly " + + with target: + assert ( + _raises_exception(lambda: tvm.build(sch.mod)) == limited + ), "Case 2 - with.VTCM memory allocation limiter does not work correctly " + + with tvm.transform.PassContext(config={"tir.vtcm_capacity": vtcm_capacity}): + assert ( + _raises_exception(lambda: tvm.build(sch.mod, target=get_hexagon_target("v68"))) + == limited + ), "Case 3 - context. VTCM memory allocation limiter does not work correctly " if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_analysis_calculate_allocated_memory.py b/tests/python/unittest/test_tir_analysis_calculate_allocated_memory.py new file mode 100644 index 000000000000..1a2d50ef5d7f --- /dev/null +++ b/tests/python/unittest/test_tir_analysis_calculate_allocated_memory.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import pytest + +import tvm +from tvm import tir +from tvm.script import tir as T + + +@T.prim_func +def scale_by_two(a: T.Buffer[(128,), "int8"], c: T.Buffer[(128,), "int8"]): + for i in T.serial(128): + with T.block("C"): + c[i] = a[i] * T.int8(2) + + +@T.prim_func +def scale_by_two_three(a: T.Buffer[(128,), "int8"], c: T.Buffer[(128,), "int8"]): + B = T.alloc_buffer([128], dtype="int8", scope="global.vtcm") + for i in T.serial(128): + with T.block("B"): + B[i] = a[i] * T.int8(2) + for i in T.serial(128): + with T.block("C"): + c[i] = B[i] * T.int8(3) + + +@pytest.mark.parametrize("primFunc,size", [(scale_by_two, 128), (scale_by_two_three, 256)]) +def test_scale_by(primFunc, size): + """Test calculate allocated bytes per scope""" + mod = tvm.IRModule.from_expr(primFunc.with_attr("global_symbol", "main")) + sch = tir.Schedule(mod, debug_mask="all") + block_c = sch.get_block("C") + (flat,) = sch.get_loops(block_c) + cache_block = sch.cache_read(block_c, 0, storage_scope="global.vtcm") + sch.compute_at(cache_block, flat) + + mod = sch.mod + mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod) + mod = tvm.tir.transform.LowerOpaqueBlock()(mod) + sizes = tvm.tir.analysis.calculate_allocated_bytes(mod["main"]) + assert sizes.get("global.vtcm", 0) == size + + +@T.prim_func +def matmul_mix_scope(a: T.handle, b: T.handle, c: T.handle) -> None: + A = T.match_buffer(a, [128, 128], scope="global") + B = T.match_buffer(b, [128, 128], scope="global") + C = T.match_buffer(c, [128, 128], scope="global") + A_allocated = T.alloc_buffer([128, 128], dtype="float32", scope="global.texture") + B_allocated = T.alloc_buffer([128, 128], dtype="float32", scope="global.texture") + C_allocated = T.alloc_buffer([128, 128], dtype="float32", scope="global") + + for i, j in T.grid(128, 128): + with T.block("A.allocated"): + A_allocated[i, j] = A[i, j] + for i, j in T.grid(128, 128): + with T.block("B.allocated"): + B_allocated[i, j] = B[i, j] + + for i, j, k in T.grid(128, 128, 128): + with T.block("update"): + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C_allocated[vi, vj] = 0.0 + C_allocated[vi, vj] = C[vi, vj] + A_allocated[vi, vk] * B_allocated[vj, vk] + + for i, j in T.grid(128, 128): + with T.block("C"): + C[i, j] = C_allocated[i, j] + + +@pytest.mark.parametrize( + "scope,size", [("global", 65536), ("global.texture", 131072), ("global.texture-nhwc", 0)] +) +def test_matmul_mix_scope(scope, size): + """Test calculate allocated bytes per scope""" + mod = tvm.IRModule({"main": matmul_mix_scope}) + mod = tvm.tir.transform.LowerInitBlock()(mod) + mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod) + mod = tvm.tir.transform.LowerOpaqueBlock()(mod) + sizes = tvm.tir.analysis.calculate_allocated_bytes(mod["main"]) + assert sizes.get(scope, 0) == size + + +if __name__ == "__main__": + tvm.testing.main() From 
63a403fd12aa801876df3d41caf03dcd3c832b01 Mon Sep 17 00:00:00 2001 From: Alexey Voronov Date: Thu, 1 Dec 2022 14:36:40 +0000 Subject: [PATCH 2/2] add requires_hexagon for building tests --- tests/python/contrib/test_hexagon/test_vtcm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/python/contrib/test_hexagon/test_vtcm.py b/tests/python/contrib/test_hexagon/test_vtcm.py index 668346b96c37..e71f890740c1 100644 --- a/tests/python/contrib/test_hexagon/test_vtcm.py +++ b/tests/python/contrib/test_hexagon/test_vtcm.py @@ -44,6 +44,7 @@ def get_scale_by_two_schedule(): return sch +@tvm.testing.requires_hexagon def test_vtcm_building(): """Test building with vtcm mem scope""" sch = get_scale_by_two_schedule() @@ -52,6 +53,7 @@ def test_vtcm_building(): assert "global.vtcm" in built.get_source("asm") +@tvm.testing.requires_hexagon @pytest.mark.parametrize("vtcm_capacity,limited", [(8192, False), (1024, False), (128, True)]) def test_vtcm_limit(vtcm_capacity, limited): """Test building with vtcm mem scope limit"""
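
Usage sketch (illustration only, not part of the patch): with both commits applied, the VTCM limit can be supplied either through the Hexagon target attribute or through the "tir.vtcm_capacity" PassContext option, and the underlying per-scope analysis can be called directly. The schedule mirrors the one in test_vtcm.py; the 128-byte capacity is just a value small enough to trip the verifier, and the expected 1024-byte figure in the comments is an estimate for this particular schedule. Building for Hexagon additionally requires an LLVM with the Hexagon backend, so the build calls are left commented out.

import tvm
from tvm import tir
from tvm.script import tir as T


@T.prim_func
def scale_by_two(buffer_a: T.Buffer[(8192,), "int8"], buffer_c: T.Buffer[(8192,), "int8"]):
    for i in T.serial(8192):
        with T.block("C"):
            buffer_c[i] = buffer_a[i] * T.int8(2)


# Same schedule as get_scale_by_two_schedule(): stage the input in VTCM per outer iteration.
mod = tvm.IRModule.from_expr(scale_by_two.with_attr("global_symbol", "main"))
sch = tir.Schedule(mod, debug_mask="all")
block_c = sch.get_block("C")
(flat,) = sch.get_loops(block_c)
outer, _, _, _ = sch.split(flat, factors=[8, 4, 2, 128])
cache_block = sch.cache_read(block_c, 0, storage_scope="global.vtcm")
sch.compute_at(cache_block, outer)

# Query the per-scope allocation directly on the lowered PrimFunc.
# With this schedule the VTCM working set should be 4 * 2 * 128 = 1024 int8 elements.
lowered = tvm.lower(sch.mod)
sizes = tvm.tir.analysis.calculate_allocated_bytes(lowered["main"])
print(sizes.get("global.vtcm", 0))  # should print 1024

# Option 1: the capacity comes from the target attribute (0 means unbounded).
target = tvm.target.Target(tvm.target.hexagon("v68", vtcm_capacity=128))
# tvm.build(sch.mod, target=target)  # should fail the VerifyVTCMLimit check (1024 > 128)

# Option 2: the capacity comes from the PassContext option when the target does not set one.
# with tvm.transform.PassContext(config={"tir.vtcm_capacity": 128}):
#     tvm.build(sch.mod, target=tvm.target.Target(tvm.target.hexagon("v68")))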