Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,12 @@ TVM_DLL size_t CalculateConstantBytes(const PrimFunc& func, const Integer& const
TVM_DLL size_t CalculateWorkspaceBytes(const PrimFunc& func,
const Integer& workspace_byte_alignment);

/*!
* \brief Calculate the allocated memory per scope in bytes needed inside the TIR PrimFunc
* \param func The TIR PrimFunc for which the the allocated memory size to be calculated
*/
TVM_DLL tvm::Map<String, Integer> CalculateAllocatedBytes(const PrimFunc& func);

/*!
* \brief Detect the lowest common ancestor(LCA) of buffer access, including both high-level
* access(BufferLoad, BufferStore) and low-level access(Load, Store and opaque access).
Expand Down Expand Up @@ -294,6 +300,16 @@ TVM_DLL Pass VerifyMemory();
*/
TVM_DLL Pass VerifyGPUCode(Map<String, PrimExpr> constraints);

/*!
* \brief Pass to checks if the size of the allocated vtcm memory satisfies the limit
*
* \param limit The limit to check.
*
* \returns The pass.
* \sa tvm::tir::CalculateAllocatedBytes
*/
TVM_DLL Pass VerifyVTCMLimit(const Integer& limit);

/*!
* \brief Statically check TIR code for out of bounds array access.
*
Expand Down
33 changes: 27 additions & 6 deletions python/tvm/autotvm/measure/measure_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def set_task(self, task):
)

def get_build_kwargs(self):
kwargs = {}
kwargs = {"checks": {}}
if (
"cuda" in self.task.target.keys
or "opencl" in self.task.target.keys
Expand All @@ -340,13 +340,15 @@ def get_build_kwargs(self):
remote = request_remote(self.key, self.host, self.port)
dev = remote.device(str(self.task.target), 0)
max_dims = dev.max_thread_dimensions
kwargs["check_gpu"] = {
kwargs["checks"]["gpu"] = {
"max_shared_memory_per_block": dev.max_shared_memory_per_block,
"max_threads_per_block": dev.max_threads_per_block,
"max_thread_x": max_dims[0],
"max_thread_y": max_dims[1],
"max_thread_z": max_dims[2],
}
if "hexagon" in self.task.target.keys:
kwargs["checks"]["hexagon"] = {"vtcm_capacity": self.task.target.vtcm_capacity}

return kwargs

Expand Down Expand Up @@ -493,11 +495,11 @@ def set_task(self, task):
return server, tracker


def _build_func_common(measure_input, runtime=None, check_gpu=None, build_option=None):
def _build_func_common(measure_input, runtime=None, checks=None, build_option=None):
"""Common part for building a configuration"""
target, task, config = measure_input
target, task.target_host = Target.canon_target_and_host(target, task.target_host)

checks = checks or {}
with target:
s, args = task.instantiate(config)

Expand Down Expand Up @@ -526,8 +528,10 @@ def _build_func_common(measure_input, runtime=None, check_gpu=None, build_option
current_add_lower_pass = list(current_config["tir.add_lower_pass"])
else:
current_add_lower_pass = []
if check_gpu:
current_add_lower_pass.append((2, gpu_verify_pass(**check_gpu)))
if checks.get("gpu"):
current_add_lower_pass.append((2, gpu_verify_pass(**checks.get("gpu"))))
if checks.get("hexagon"):
current_add_lower_pass.append((2, vtcm_verify_pass(**checks.get("hexagon"))))
current_config["tir.add_lower_pass"] = current_add_lower_pass

with tvm.ir.transform.PassContext(
Expand Down Expand Up @@ -872,3 +876,20 @@ def verify_pass(f, *_):
return f

return tvm.tir.transform.prim_func_pass(verify_pass, opt_level=0)


def vtcm_verify_pass(**kwargs):
"""Verify the validity of a hexagon kernel.
This pass will check vtcm memory usage.
"""

def verify_pass(f, *_):
sizes = tvm.tir.analysis.calculate_allocated_bytes(f)
vtcm_capacity = kwargs.get("vtcm_capacity", 0)
vtcm_allocated = sizes.get("global.vtcm", 0)
if 0 < vtcm_capacity < vtcm_allocated:
raise InstantiationError("Skipped because of invalid vtcm memory usage limit")

return f

return tvm.tir.transform.prim_func_pass(verify_pass, opt_level=0)
8 changes: 8 additions & 0 deletions python/tvm/target/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,10 @@ def thread_warp_size(self):
def max_function_args(self):
return int(self.attrs.get("max_function_args", -1))

@property
def vtcm_capacity(self):
return int(self.attrs.get("vtcm-capacity", 0))

@property
def device_name(self):
return str(self.attrs.get("device", ""))
Expand Down Expand Up @@ -642,6 +646,8 @@ def hexagon(cpu_ver="v66", **kwargs):
Whether to use IEEE HVX instructions
num_cores : int (default: 4)
The number of HVX threads. This attribute is required by meta scheduler.
vtcm_capacity: int (default: 0)
Hexagon VTCM capacity limitation. If the value is 0, the capacity is treated as unbounded.

Note: Floating point support in HVX requires LLVM 14+.
"""
Expand Down Expand Up @@ -675,6 +681,7 @@ def get_arch_version(cpu_ver):
"llvm_options": None,
"use_qfloat": arch_version >= 68,
"use_ieee_fp": False,
"vtcm_capacity": 0,
}
config.update(kwargs)

Expand Down Expand Up @@ -748,6 +755,7 @@ def create_llvm_options(cpu_ver, config): # pylint: disable=unused-argument

num_cores = config["num_cores"] if "num_cores" in kwargs else 4
args_list.append("--num-cores=%d" % num_cores)
args_list.append("--vtcm-capacity=%d" % config["vtcm_capacity"])

return Target(" ".join(["hexagon"] + args_list))

Expand Down
16 changes: 16 additions & 0 deletions python/tvm/tir/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,22 @@ def calculate_constant_bytes(func: PrimFunc, constant_byte_alignment: int) -> in
return _ffi_api.calculate_constant_bytes(func, constant_byte_alignment) # type: ignore


def calculate_allocated_bytes(func: PrimFunc) -> Dict[str, int]:
"""Calculate allocated memory per memory scope required by TIR PrimFuncs.

Parameters
----------
func: tvm.tir.PrimFunc
The function to be detected.

Returns
-------
result : Dict[String, int]
Allocated memory size per scope in bytes.
"""
return _ffi_api.calculate_allocated_bytes(func) # type: ignore


def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]:
"""Detect the lowest common ancestor(LCA) of buffer access, including both high-level
access(BufferLoad, BufferStore) and low-level access(Load, Store and opaque access).
Expand Down
11 changes: 11 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,17 @@ def VerifyMemory():
return _ffi_api.VerifyMemory() # type: ignore


def VerifyVTCMLimit(limit: int):
"""Verify if the size of the allocated vtcm memory satisfies the limit.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.VerifyVTCMLimit(limit) # type: ignore


# pylint: disable=no-else-return,inconsistent-return-statements
def HoistIfThenElse(variant: Optional[str] = None):
"""Hoist loop-invariant IfThenElse nodes to outside the eligible loops.
Expand Down
7 changes: 7 additions & 0 deletions src/auto_scheduler/feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1401,6 +1401,13 @@ void GetPerStoreFeaturesWorkerFunc(const SearchTask& task, const State& state, i
const auto& optimize = tir::transform::Sequential(pass_list);
optimize(mod);
}
if (IsHexagonTask(task)) {
Target target = task->target;
const auto vtcm_capacity = target->GetAttr<Integer>("vtcm-capacity").value().IntValue();
const auto& optimize =
tir::transform::Sequential({tir::transform::VerifyVTCMLimit(vtcm_capacity)});
optimize(mod);
}
const auto& optimize =
tir::transform::Sequential(Array<tvm::transform::Pass>{tir::transform::Simplify()});
mod = optimize(std::move(mod));
Expand Down
5 changes: 5 additions & 0 deletions src/auto_scheduler/search_policy/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,11 @@ inline bool IsGPUTask(const SearchTask& task) {
device_type == kDLMetal || device_type == kDLROCM || device_type == kOpenGL;
}

/*! \brief Return whether the search task is targeting a Hexagon. */
inline bool IsHexagonTask(const SearchTask& task) {
return (task)->target->GetTargetDeviceType() == kDLHexagon;
}

/*! \brief Return whether the search task is targeting a CUDA GPU. */
inline bool IsCUDATask(const SearchTask& task) {
return (task)->target->GetTargetDeviceType() == kDLCUDA;
Expand Down
17 changes: 15 additions & 2 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_async_copy", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.merge_async_commit_queue_scope", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.instrument_lwp", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.dma_bypass_cache", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.vtcm_capacity", Integer);

using tvm::Array;
using tvm::transform::Pass;
Expand Down Expand Up @@ -225,8 +226,6 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
if (!disable_storage_rewrite) {
pass_list.push_back(tir::transform::StorageRewrite());
}
// LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
pass_list.push_back(tir::transform::LowerVtcmAlloc());
bool use_async_copy = pass_ctx->GetConfig<Bool>("tir.use_async_copy", Bool(false)).value();

if (use_async_copy) {
Expand Down Expand Up @@ -532,11 +531,25 @@ runtime::Module build(const IRModule& funcs, const Target& target_arg,
return TIRToRuntime(inputs, target_host);
}

int64_t GetVTCMCapacity(Target target, const transform::PassContext& pass_ctx) {
if (!target.defined()) target = Target::Current(/*allow_not_defined=*/true);
if (target.defined() && target->kind->name == "hexagon") {
auto value = Downcast<Integer>(target->attrs.at("vtcm-capacity"))->value;
if (value > 0) return value;
}
return pass_ctx->GetConfig<Integer>("tir.vtcm_capacity", Integer(0)).value()->value;
}

transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target) {
transform::PassContext pass_ctx = transform::PassContext::Current();

Array<Pass> mixed_pass_list;

// VerifyVTCMLimit must occur before LowerVtcmAlloc
mixed_pass_list.push_back(tir::transform::VerifyVTCMLimit(GetVTCMCapacity(target, pass_ctx)));
// LowerVtcmAlloc must occur after any transformations that modify memory allocation locations
mixed_pass_list.push_back(tir::transform::LowerVtcmAlloc());

mixed_pass_list.push_back(tir::transform::BindTarget(target));

mixed_pass_list.push_back(tir::transform::VerifyMemory());
Expand Down
1 change: 1 addition & 0 deletions src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon)
.add_attr_option<String>("mtriple")
.add_attr_option<Array<String>>("llvm-options")
.add_attr_option<Integer>("num-cores")
.add_attr_option<Integer>("vtcm-capacity")
.set_default_keys({"hexagon"});

TVM_REGISTER_TARGET_KIND("stackvm", kDLCPU);
Expand Down
117 changes: 117 additions & 0 deletions src/tir/analysis/calculate_allocated_memory.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tir/analysis/calculate_allocated_memory.cc
* \brief Calculate allocated memory per memory scope required by PrimFuncs.
*/
#include <tvm/arith/analyzer.h>
#include <tvm/runtime/container/map.h>
#include <tvm/runtime/device_api.h>
#include <tvm/tir/analysis.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/usmp/utils.h>

#include <algorithm>
#include <map>
#include <unordered_map>

namespace tvm {
namespace tir {

template <typename T>
class AllocationCalculator : public StmtExprVisitor {
public:
AllocationCalculator() = default;
tvm::Map<String, Integer> operator()(const PrimFunc& func);

private:
void VisitStmt_(const T* op) override;
std::unordered_map<std::string, int64_t> _max_size;
std::unordered_map<std::string, int64_t> _current_size;
};

template <typename T>
tvm::Map<String, Integer> AllocationCalculator<T>::operator()(const PrimFunc& func) {
this->VisitStmt(func->body);
tvm::Map<String, Integer> res;
for (auto [k, v] : _max_size) {
res.Set(String(k), Integer(v));
}
return res;
}

std::string GetStorageScope(const Var& var) {
auto* ptr = var->type_annotation.as<PointerTypeNode>();
ICHECK(ptr) << "Buffer Var's type annotation must be of PointerType";
return ptr->storage_scope;
}

template <typename T>
void AllocationCalculator<T>::VisitStmt_(const T* op) {
std::string storage_scope = GetStorageScope(op->buffer_var);
auto search = _current_size.find(storage_scope);
if (search == _current_size.end()) {
_current_size[storage_scope] = 0;
_max_size[storage_scope] = 0;
}
auto size = op->ConstantAllocationSize() * op->dtype.bytes() * op->dtype.lanes();
_current_size[storage_scope] += size;
_max_size[storage_scope] = std::max(_current_size[storage_scope], _max_size[storage_scope]);
StmtExprVisitor::VisitStmt(op->body);
_current_size[storage_scope] -= size;
}

tvm::Map<String, Integer> CalculateAllocatedBytes(const PrimFunc& func) {
return AllocationCalculator<AllocateNode>()(func);
}

TVM_REGISTER_GLOBAL("tir.analysis.calculate_allocated_bytes").set_body_typed([](PrimFunc func) {
return CalculateAllocatedBytes(func);
});

namespace transform {

Pass VerifyVTCMLimit(const Integer& limit) {
auto pass_func = [=](IRModule mod, PassContext ctx) {
for (auto kv : mod->functions) {
if (auto* n = kv.second.as<PrimFuncNode>()) {
auto func = GetRef<PrimFunc>(n);
auto sizes = CalculateAllocatedBytes(func);
const auto vtcm_allocated = sizes.Get("global.vtcm").value_or(0);
if (limit.IntValue() > 0 && vtcm_allocated.IntValue() > limit.IntValue()) {
LOG(FATAL) << "RuntimeError: The global.vtcm memory allocation limit has been "
"exceeded(allocated: "
<< vtcm_allocated << ", limit: " << limit << ").\n"
<< "In function\n"
<< func;
}
}
}
return mod;
};
return tvm::transform::CreateModulePass(pass_func, 0, "tir.calculate_allocated_bytes", {});
}

TVM_REGISTER_GLOBAL("tir.transform.VerifyVTCMLimit").set_body_typed(VerifyVTCMLimit);

} // namespace transform
} // namespace tir
} // namespace tvm
4 changes: 2 additions & 2 deletions tests/python/contrib/test_hexagon/infrastructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,7 @@ def quantize_np(arr_np: numpy.ndarray, dtype: str):
return quant_np, scale, zero_point


def get_hexagon_target(cpu_ver: str) -> tvm.target.Target:
def get_hexagon_target(cpu_ver: str, **kwargs) -> tvm.target.Target:
"""Creates a Hexagon target"""
target = tvm.target.hexagon(cpu_ver)
target = tvm.target.hexagon(cpu_ver, **kwargs)
return tvm.target.Target(target, host=target)
Loading