diff --git a/include/tvm/tir/usmp/algorithms.h b/include/tvm/tir/usmp/algorithms.h index 77276a2c931c..e2f2b6fb73f3 100644 --- a/include/tvm/tir/usmp/algorithms.h +++ b/include/tvm/tir/usmp/algorithms.h @@ -54,6 +54,17 @@ Map GreedyBySize(const Array& buffer_inf Map GreedyByConflicts(const Array& buffer_info_arr, const Integer& memory_pressure); +/*! + * \brief The Hill-Climb algorithm to plan memory + * + * This will perform a hill climbing algorithm in deciding the offsets + * within provided Pools. + * + * \return A Map of BufferInfo objects and their associated PoolAllocation + */ +Map HillClimb(const Array& buffer_info_arr, + const Integer& memory_pressure); + } // namespace algo } // namespace usmp } // namespace tir diff --git a/src/tir/usmp/unified_static_memory_planner.cc b/src/tir/usmp/unified_static_memory_planner.cc index 5a2125077566..3b941d3cc021 100644 --- a/src/tir/usmp/unified_static_memory_planner.cc +++ b/src/tir/usmp/unified_static_memory_planner.cc @@ -46,7 +46,8 @@ static constexpr const char* kDefaultAlgo = "greedy_by_size"; static std::unordered_map( const Array&, const Integer&)>> algorithms{{"greedy_by_size", algo::GreedyBySize}, - {"greedy_by_conflicts", algo::GreedyByConflicts}}; + {"greedy_by_conflicts", algo::GreedyByConflicts}, + {"hill_climb", algo::HillClimb}}; IRModule PlanMemory(const IRModule& mod, String algo) { VLOG(1) << "workspace required = " << CalculateModuleWorkspaceSize(mod); @@ -55,7 +56,7 @@ IRModule PlanMemory(const IRModule& mod, String algo) { Array buffer_info_arr = CreateArrayBufferInfo(buffer_info_analysis->buffer_info_stmts); CHECK(algorithms.count(algo)) << "The selected USMP algorithm : " << algo - << "is not defined. Please define it in the above algorithms map."; + << " is not defined. Please define it in the above algorithms map."; Map buffer_info_pool_allocations = algorithms[algo](buffer_info_arr, buffer_info_analysis->memory_pressure); Map stmt_pool_allocations = AssignStmtPoolAllocations( diff --git a/tests/python/relay/aot/test_crt_aot_usmp.py b/tests/python/relay/aot/test_crt_aot_usmp.py index 47495aaa16c8..6a040d9a9e79 100644 --- a/tests/python/relay/aot/test_crt_aot_usmp.py +++ b/tests/python/relay/aot/test_crt_aot_usmp.py @@ -212,6 +212,7 @@ def test_byoc_microtvm(merge_compiler_regions): [ (MOBILENET_V1_URL, "greedy_by_size", 4845696), (MOBILENET_V1_URL, "greedy_by_conflicts", 4444288), + (MOBILENET_V1_URL, "hill_climb", 3240064), ], ) def test_tflite_model(model_url, usmp_algo, workspace_size):