Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions include/tvm/tir/usmp/algorithms.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,17 @@ Map<BufferInfo, PoolAllocation> GreedyBySize(const Array<BufferInfo>& buffer_inf
Map<BufferInfo, PoolAllocation> GreedyByConflicts(const Array<BufferInfo>& buffer_info_arr,
const Integer& memory_pressure);

/*!
* \brief The Hill-Climb algorithm to plan memory
*
* This will perform a hill climbing algorithm in deciding the offsets
* within provided Pools.
*
* \return A Map of BufferInfo objects and their associated PoolAllocation
*/
Map<BufferInfo, PoolAllocation> HillClimb(const Array<BufferInfo>& buffer_info_arr,
const Integer& memory_pressure);

} // namespace algo
} // namespace usmp
} // namespace tir
Expand Down
5 changes: 3 additions & 2 deletions src/tir/usmp/unified_static_memory_planner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,8 @@ static constexpr const char* kDefaultAlgo = "greedy_by_size";
static std::unordered_map<String, std::function<Map<BufferInfo, PoolAllocation>(
const Array<BufferInfo>&, const Integer&)>>
algorithms{{"greedy_by_size", algo::GreedyBySize},
{"greedy_by_conflicts", algo::GreedyByConflicts}};
{"greedy_by_conflicts", algo::GreedyByConflicts},
{"hill_climb", algo::HillClimb}};

IRModule PlanMemory(const IRModule& mod, String algo) {
VLOG(1) << "workspace required = " << CalculateModuleWorkspaceSize(mod);
Expand All @@ -55,7 +56,7 @@ IRModule PlanMemory(const IRModule& mod, String algo) {
Array<BufferInfo> buffer_info_arr =
CreateArrayBufferInfo(buffer_info_analysis->buffer_info_stmts);
CHECK(algorithms.count(algo)) << "The selected USMP algorithm : " << algo
<< "is not defined. Please define it in the above algorithms map.";
<< " is not defined. Please define it in the above algorithms map.";
Map<BufferInfo, PoolAllocation> buffer_info_pool_allocations =
algorithms[algo](buffer_info_arr, buffer_info_analysis->memory_pressure);
Map<Stmt, PoolAllocation> stmt_pool_allocations = AssignStmtPoolAllocations(
Expand Down
1 change: 1 addition & 0 deletions tests/python/relay/aot/test_crt_aot_usmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def test_byoc_microtvm(merge_compiler_regions):
[
(MOBILENET_V1_URL, "greedy_by_size", 4845696),
(MOBILENET_V1_URL, "greedy_by_conflicts", 4444288),
(MOBILENET_V1_URL, "hill_climb", 3240064),
],
)
def test_tflite_model(model_url, usmp_algo, workspace_size):
Expand Down