diff --git a/include/tvm/tir/usmp/algo/greedy.h b/include/tvm/tir/usmp/algo/greedy.h
new file mode 100644
index 000000000000..8f0ed873593e
--- /dev/null
+++ b/include/tvm/tir/usmp/algo/greedy.h
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file include/tvm/tir/usmp/algo/greedy.h
+ * \brief This header file contains helper methods used in greedy algorithms
+ * for planning memory for USMP
+ */
+#pragma once
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/usmp/utils.h>
+
+#include <unordered_map>
+#include <vector>
+
+namespace tvm {
+namespace tir {
+namespace usmp {
+namespace algo {
+
+/*!
+ * \brief This is the base class for greedy algorithms, where the sorting
+ * is specialized in the extended classes based on the greedy criteria.
+ */
+class GreedyBase {
+ public:
+  GreedyBase() {}
+  /*!
+   * \brief This function should be implemented by the extended classes to sort the BufferInfo
+   * objects based on a criterion and then call PostSortAllocation.
+   */
+  virtual Map<BufferInfo, PoolAllocation> PlanMemory(const Array<BufferInfo>& buffer_info_arr) = 0;
+
+ protected:
+  /*!
+   * \brief Rounds up the offset to satisfy the alignment requirement.
+   */
+  size_t round_up_to_byte_alignment(const size_t& non_aligned_byte_offset,
+                                    const int& byte_alignment);
+
+  /*!
+   * \brief A helper function to check whether an offset is valid given the constraints.
+   */
+  bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset,
+                        const size_t& size_bytes);
+
+  /*!
+   * \brief Selects a pool for placement from the given set of ordered pool candidates.
+   */
+  PoolInfo SelectPlacementPool(
+      const BufferInfo& buf_info,
+      const std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual>& pool_offsets);
+
+  /*!
+   * \brief This is the base allocation function that works on BufferInfo objects sorted
+   * by the greedy heuristic. The sorting algorithm has to be called before calling this.
+   */
+  Map<BufferInfo, PoolAllocation> PostSortAllocation(
+      const std::vector<BufferInfo>& buffer_info_vec);
+};
+
+}  // namespace algo
+}  // namespace usmp
+}  // namespace tir
+}  // namespace tvm
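For reference, `round_up_to_byte_alignment` declared above is plain round-up-to-multiple integer arithmetic. A minimal Python rendering of the same formula (illustrative only, not part of the patch):

```python
def round_up_to_byte_alignment(non_aligned_byte_offset: int, byte_alignment: int) -> int:
    """Round an offset up to the next multiple of byte_alignment (integer arithmetic only)."""
    return ((non_aligned_byte_offset + byte_alignment - 1) // byte_alignment) * byte_alignment

assert round_up_to_byte_alignment(70, 16) == 80  # 70 is not 16-aligned; bump to 80
assert round_up_to_byte_alignment(64, 16) == 64  # already-aligned offsets are unchanged
```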
diff --git a/src/tir/usmp/algo/greedy.cc b/src/tir/usmp/algo/greedy.cc
index 5e1ce5f289c1..a434d206162f 100644
--- a/src/tir/usmp/algo/greedy.cc
+++ b/src/tir/usmp/algo/greedy.cc
@@ -39,6 +39,7 @@
 #include <tvm/tir/builtin.h>
 #include <tvm/tir/function.h>
 #include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/usmp/algo/greedy.h>
 #include <tvm/tir/usmp/utils.h>
 
 namespace tvm {
@@ -47,109 +48,93 @@ namespace usmp {
 namespace algo {
 
-/*!
- * \brief This is the base class for Greedy Algorithms where the sorting
- * is specialized in the extended classes based on the greedy criteria.
- */
-class GreedyBase {
- public:
-  GreedyBase() {}
-  /*!
-   * \brief This function should be implemented by the extended classes to sort the BufferInfo
-   * objects based on a criteria and then calling PostSortAllocation.
-   */
-  virtual Map<BufferInfo, PoolAllocation> PlanMemory(const Array<BufferInfo>& buffer_info_arr) = 0;
-
- protected:
-  /*!
-   * \brief Rounds up the offset to satisfy the alignement requirement
-   */
-  size_t round_up_to_byte_alignment(const size_t& non_aligned_byte_offset,
-                                    const int& byte_alignment) {
-    return ((non_aligned_byte_offset + byte_alignment - 1) / byte_alignment) * byte_alignment;
-  }
-
-  /*!
-   * \brief A helper function check whether a offset is valid given the constraints
-   */
-  bool IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset,
-                        const size_t& size_bytes) {
-    if (candidate_pool->size_hint_bytes == -1) {
-      // this means pool is not bounded
-      return true;
-    }
-    auto pool_size = static_cast<size_t>(candidate_pool->size_hint_bytes->value);
-    auto max_address = next_offset + size_bytes;
-    if (max_address <= pool_size) {
-      return true;
-    }
-    return false;
-  }
-
-  /*!
-   * \brief Selects a pool for placement in the given set of ordered pool candidates
-   */
-  PoolInfo SelectPlacementPool(
-      const BufferInfo& buf_info,
-      const std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual>& pool_offsets) {
-    // Here the pool candidates are ordered when it is consumed by the algorithm.
-    // This could be from order the user has specified. However, schedulers are
-    // welcome to change the order for performance reasons.
-    for (const auto& pool_info : buf_info->pool_candidates) {
-      if (pool_offsets.count(pool_info)) {
-        return pool_info;
-      }
-    }
-    CHECK(false) << "TVM USMP Error: the space available in the provided pools exceeded when "
-                    "trying to allocate the buffer : "
-                 << buf_info << "\n. Please increase the size_hints for memory pools.";
-    return PoolInfo();
-  }
-
-  /*!
-   * \brief This is the base allocation function that works on sorted BufferInfo objects based
-   * on the greedy heuristic. The sorting algorithm has to be called before calling this.
-   */
-  Map<BufferInfo, PoolAllocation> PostSortAllocation(
-      const std::vector<BufferInfo>& buffer_info_vec) {
-    Map<BufferInfo, PoolAllocation> pool_allocations;
-    for (const auto& buf_info : buffer_info_vec) {
-      std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual> pool_offset_candidates;
-      for (const auto& pool_info : buf_info->pool_candidates) {
-        // Mark pool candidates that satisfy the size constraints.
-        if (IsValidPlacement(pool_info, 0, buf_info->size_bytes->value)) {
-          pool_offset_candidates[pool_info] = 0;
-        }
-      }
-
-      for (const auto& conflict_buf_info_obj : buf_info->conflicts) {
-        auto conflict_buf_info = Downcast<BufferInfo>(conflict_buf_info_obj);
-        size_t next_offset = 0;
-        // We only look at already allocated BufferInfo in-terms of conflicts.
-        if (pool_allocations.count(conflict_buf_info)) {
-          auto pool_allocation = pool_allocations[conflict_buf_info];
-          next_offset = pool_allocation->byte_offset + conflict_buf_info->size_bytes;
-          next_offset =
-              round_up_to_byte_alignment(next_offset, conflict_buf_info->alignment->value);
-          // Checks whether the next offset in the same pool as the conflicting BufferInfo is valid.
-          if (IsValidPlacement(pool_allocation->pool_info, next_offset,
-                               buf_info->size_bytes->value)) {
-            // There could be multiple conflicting BufferInfo in the same pool.
-            // Thus, we need to make sure we pick the largest offset of them all.
-            if (next_offset > pool_offset_candidates[pool_allocation->pool_info]) {
-              pool_offset_candidates[pool_allocation->pool_info] = next_offset;
-            }
-          } else {
-            pool_offset_candidates.erase(pool_allocation->pool_info);
-          }
-        }
-      }
-      auto selected_pool = SelectPlacementPool(buf_info, pool_offset_candidates);
-      pool_allocations.Set(
-          buf_info, PoolAllocation(selected_pool, Integer(pool_offset_candidates[selected_pool])));
-    }
-    return pool_allocations;
-  }
-};
+/*!
+ * \brief Rounds up the offset to satisfy the alignment requirement.
+ */
+size_t GreedyBase::round_up_to_byte_alignment(const size_t& non_aligned_byte_offset,
+                                              const int& byte_alignment) {
+  return ((non_aligned_byte_offset + byte_alignment - 1) / byte_alignment) * byte_alignment;
+}
+
+/*!
+ * \brief A helper function to check whether an offset is valid given the constraints.
+ */
+bool GreedyBase::IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset,
+                                  const size_t& size_bytes) {
+  if (candidate_pool->size_hint_bytes == -1) {
+    // this means the pool is not bounded
+    return true;
+  }
+  auto pool_size = static_cast<size_t>(candidate_pool->size_hint_bytes->value);
+  auto max_address = next_offset + size_bytes;
+  if (max_address <= pool_size) {
+    return true;
+  }
+  return false;
+}
+
+/*!
+ * \brief Selects a pool for placement from the given set of ordered pool candidates.
+ */
+PoolInfo GreedyBase::SelectPlacementPool(
+    const BufferInfo& buf_info,
+    const std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual>& pool_offsets) {
+  // The pool candidates are already ordered by the time the algorithm consumes them.
+  // This could be the order the user specified. However, schedulers are
+  // welcome to change the order for performance reasons.
+  for (const auto& pool_info : buf_info->pool_candidates) {
+    if (pool_offsets.count(pool_info)) {
+      return pool_info;
+    }
+  }
+  CHECK(false) << "TVM USMP Error: the space available in the provided pools was exceeded when "
+                  "trying to allocate the buffer : "
+               << buf_info << "\n. Please increase the size_hints for memory pools.";
+  return PoolInfo();
+}
+
+/*!
+ * \brief This is the base allocation function that works on BufferInfo objects sorted
+ * by the greedy heuristic. The sorting algorithm has to be called before calling this.
+ */
+Map<BufferInfo, PoolAllocation> GreedyBase::PostSortAllocation(
+    const std::vector<BufferInfo>& buffer_info_vec) {
+  Map<BufferInfo, PoolAllocation> pool_allocations;
+  for (const auto& buf_info : buffer_info_vec) {
+    std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual> pool_offset_candidates;
+    for (const auto& pool_info : buf_info->pool_candidates) {
+      // Mark pool candidates that satisfy the size constraints.
+      if (IsValidPlacement(pool_info, 0, buf_info->size_bytes->value)) {
+        pool_offset_candidates[pool_info] = 0;
+      }
+    }
+
+    for (const auto& conflict_buf_info_obj : buf_info->conflicts) {
+      auto conflict_buf_info = Downcast<BufferInfo>(conflict_buf_info_obj);
+      size_t next_offset = 0;
+      // We only look at already allocated BufferInfo in terms of conflicts.
+      if (pool_allocations.count(conflict_buf_info)) {
+        auto pool_allocation = pool_allocations[conflict_buf_info];
+        next_offset = pool_allocation->byte_offset + conflict_buf_info->size_bytes;
+        next_offset = round_up_to_byte_alignment(next_offset, conflict_buf_info->alignment->value);
+        // Checks whether the next offset in the same pool as the conflicting BufferInfo is valid.
+        if (IsValidPlacement(pool_allocation->pool_info, next_offset,
+                             buf_info->size_bytes->value)) {
+          // There could be multiple conflicting BufferInfo in the same pool.
+          // Thus, we need to make sure we pick the largest offset of them all.
+          if (next_offset > pool_offset_candidates[pool_allocation->pool_info]) {
+            pool_offset_candidates[pool_allocation->pool_info] = next_offset;
+          }
+        } else {
+          pool_offset_candidates.erase(pool_allocation->pool_info);
+        }
+      }
+    }
+    auto selected_pool = SelectPlacementPool(buf_info, pool_offset_candidates);
+    pool_allocations.Set(
+        buf_info, PoolAllocation(selected_pool, Integer(pool_offset_candidates[selected_pool])));
+  }
+  return pool_allocations;
+}
 
 /*!
  * \brief This class implements Greedy by the size of BufferInfo
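The `PostSortAllocation` loop above is the core of all the greedy planners: buffers arrive pre-sorted, and each one is offered, per candidate pool, the lowest aligned offset past every conflict that has already been placed. A toy single-pool model in Python (hypothetical dict-based buffers; the real code additionally tracks multiple pool candidates and evicts pools that overflow):

```python
def post_sort_allocation(sorted_bufs):
    """Toy single-pool model of GreedyBase::PostSortAllocation: place each buffer
    at the largest aligned end-offset among its already-placed conflicts."""
    placements = {}  # id(buffer dict) -> chosen byte offset
    for buf in sorted_bufs:
        offset = 0
        for conflict in buf["conflicts"]:
            placed = placements.get(id(conflict))
            if placed is None:  # only already-allocated conflicts constrain us
                continue
            end = placed + conflict["size"]
            align = conflict["align"]
            offset = max(offset, ((end + align - 1) // align) * align)
        placements[id(buf)] = offset
    return placements


# Two live-at-the-same-time 10-byte buffers with 16-byte alignment:
# the first lands at offset 0, the second is pushed past it to offset 16.
a = {"size": 10, "align": 16, "conflicts": []}
b = {"size": 10, "align": 16, "conflicts": [a]}
a["conflicts"].append(b)
assert post_sort_allocation([a, b]) == {id(a): 0, id(b): 16}
```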
diff --git a/src/tir/usmp/algo/hill_climb.cc b/src/tir/usmp/algo/hill_climb.cc
new file mode 100644
index 000000000000..c4ed73eb2feb
--- /dev/null
+++ b/src/tir/usmp/algo/hill_climb.cc
@@ -0,0 +1,339 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/tir/usmp/algo/hill_climb.cc
+ * \brief Implements the hill climb memory planning algorithm
+ */
+#include <tvm/arith/analyzer.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/usmp/algo/greedy.h>
+#include <tvm/tir/usmp/utils.h>
+
+#include <algorithm>
+#include <numeric>
+#include <sstream>
+
+namespace tvm {
+namespace tir {
+namespace usmp {
+namespace algo {
+
+/*
+ * Simulated annealing / Hill climb
+ *
+ * Works by continuously invoking the 'greedy-by-size' allocation,
+ * assessing the result, and introducing permutations to the allocation
+ * order which hopefully will lead to a more 'compact' memory allocation.
+ */
+class HillClimbAllocator : public GreedyBase {
+ private:
+  size_t memory_pressure_ = 0;
+
+ public:
+  explicit HillClimbAllocator(size_t memory_pressure)
+      : GreedyBase(), memory_pressure_(memory_pressure) {}
+
+ protected:
+  using alloc_map_t = std::unordered_map<const BufferInfoNode*, PoolAllocation>;
+
+  /*
+   * Initial sorting routine
+   */
+  void sort_vector(std::vector<BufferInfo>* buffer_info_vec) {
+    std::sort(buffer_info_vec->begin(), buffer_info_vec->end(),
+              [](const BufferInfo& a, const BufferInfo& b) {
+                if (a->size_bytes->value == b->size_bytes->value) {
+                  if (a->conflicts.size() == b->conflicts.size()) {
+                    return std::string(a->name_hint->data) > std::string(b->name_hint->data);
+                  } else {
+                    return a->conflicts.size() > b->conflicts.size();
+                  }
+                }
+                return a->size_bytes->value > b->size_bytes->value;
+              });
+  }
+
+  /*
+   * HillClimb's version of greedy allocation
+   * \param buffer_info_vec - buffers in a specific order for allocation
+   */
+  alloc_map_t greedy(const std::vector<BufferInfo>& buffer_info_vec) {
+    alloc_map_t pool_allocations(buffer_info_vec.size());
+    for (const auto& buf_info : buffer_info_vec) {
+      std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual> pool_offset_candidates;
+      for (const auto& pool_info : buf_info->pool_candidates) {
+        if (IsValidPlacement(pool_info, 0, buf_info->size_bytes->value)) {
+          pool_offset_candidates[pool_info] = 0;
+        }
+      }
+
+      std::vector<const BufferInfoNode*> buf_conf;
+      for (const auto& conflict_buf_info_obj : buf_info->conflicts) {
+        const BufferInfoNode* conflict_buf_info = conflict_buf_info_obj.as<BufferInfoNode>();
+        if (pool_allocations.end() != pool_allocations.find(conflict_buf_info)) {
+          buf_conf.push_back(conflict_buf_info);
+        }
+      }
+
+      // extra sorting for pool offsets
+      std::sort(buf_conf.begin(), buf_conf.end(),
+                [&pool_allocations](const auto* a, const auto* b) {
+                  return pool_allocations[a]->byte_offset->value <
+                         pool_allocations[b]->byte_offset->value;
+                });
+
+      for (const auto* conflict_buf_info : buf_conf) {
+        size_t next_offset = 0;
+        auto pool_allocation = pool_allocations[conflict_buf_info];
+        next_offset = pool_allocation->byte_offset + conflict_buf_info->size_bytes;
+        next_offset = round_up_to_byte_alignment(next_offset, conflict_buf_info->alignment->value);
+        if (!pool_offset_candidates.count(pool_allocation->pool_info)) {
+          continue;
+        }
+        if (IsValidPlacement(pool_allocation->pool_info, next_offset,
+                             buf_info->size_bytes->value)) {
+          if (next_offset > pool_offset_candidates[pool_allocation->pool_info] &&
+              pool_offset_candidates[pool_allocation->pool_info] +
+                      static_cast<size_t>(buf_info->size_bytes) >
+                  static_cast<size_t>(pool_allocation->byte_offset)) {
+            pool_offset_candidates[pool_allocation->pool_info] = next_offset;
+          }
+        } else {
+          pool_offset_candidates.erase(pool_allocation->pool_info);
+        }
+      }
+      auto selected_pool = SelectPlacementPool(buf_info, pool_offset_candidates);
+      pool_allocations[buf_info.as<BufferInfoNode>()] =
+          PoolAllocation(selected_pool, Integer(pool_offset_candidates[selected_pool]));
+    }
+    return pool_allocations;
+  }
+
+  /*
+   * Finds the highest allocated memory address for each pool
+   */
+  std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual> find_highest(
+      alloc_map_t* pool_allocations) {
+    std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual> pool_sizes;
+    for (const auto& it : *pool_allocations) {
+      const BufferInfoNode* buf = it.first;
+      const PoolAllocation& pa = it.second;
+      size_t high_sz = pa->byte_offset + buf->size_bytes;
+      if (pool_sizes[pa->pool_info] <= high_sz) {
+        pool_sizes[pa->pool_info] = high_sz;
+      }
+    }
+    return pool_sizes;
+  }
+
+  /*
+   * Collects the lists of first and second level neighbors for the provided buf.
+   * First level are the immediate neighbors of the buf and
+   * second level are the immediate neighbors of the first level nodes.
+   */
+  template <typename TPos>
+  void collect_neighbor_lists(const BufferInfoNode* buf,
+                              std::vector<const BufferInfoNode*>* first_level,
+                              std::vector<const BufferInfoNode*>* second_level, const TPos& _pos) {
+    std::unordered_map<int, const BufferInfoNode*> first_level_set;
+    std::unordered_map<int, const BufferInfoNode*> second_level_set;
+
+    auto buf_pos = _pos(buf);
+    for (const auto& c1 : buf->conflicts) {
+      const auto* c1_buf = c1.as<BufferInfoNode>();
+      int c1_pos = _pos(c1_buf);
+      if (buf_pos > c1_pos) {
+        first_level_set[c1_pos] = c1_buf;
+      }
+      int c2_pos = -1;
+      for (const auto& c2 : c1_buf->conflicts) {
+        const auto c2_buf = c2.as<BufferInfoNode>();
+        if (c1_pos > (c2_pos = _pos(c2_buf))) {
+          second_level_set[c2_pos] = c2_buf;
+        }
+      }
+    }
+
+    for (const auto& i : first_level_set) {
+      first_level->push_back(i.second);
+    }
+    for (const auto& i : second_level_set) {
+      second_level->push_back(i.second);
+    }
+  }
+
+ public:
+  Map<BufferInfo, PoolAllocation> PlanMemory(const Array<BufferInfo>& buffer_info_arr) {
+// rand_r does not exist on the Windows platform
+#if defined(__linux__) || defined(__ANDROID__)
+    unsigned int _seedp = 0;
+#define rnd_func() rand_r(&_seedp)
+#else
+#define rnd_func() rand()
+#endif
+
+    std::vector<BufferInfo> buffer_info_vec;
+    for (const auto& buffer_info : buffer_info_arr) {
+      ICHECK(buffer_info->pool_candidates.size())
+          << "Cannot process buffer \"" << buffer_info->name_hint << "\" with no pool candidates";
+      buffer_info_vec.push_back(std::move(buffer_info));
+    }
+
+    sort_vector(&buffer_info_vec);
+
+    // populate the positional index map
+    std::unordered_map<const BufferInfoNode*, int> _pos_map;
+    for (size_t index = 0; index < buffer_info_vec.size(); ++index) {
+      _pos_map[buffer_info_vec[index].as<BufferInfoNode>()] = index;
+    }
+
+    size_t total_size = 0;
+    int attempts = 0;
+
+    int swap_i1 = -1;
+    int swap_i2 = -1;
+    size_t desired_bytes_ = memory_pressure_;
+    constexpr auto _max_attempts = 500;
+    alloc_map_t rollback_pool_allocations;
+    alloc_map_t result_pool_allocations;
+    alloc_map_t pool_allocations;
+
+    auto swap_buffers = [&buffer_info_vec, &_pos_map](int i1, int i2) {
+      if (i1 == i2) return;
+      auto b1 = buffer_info_vec[i1];
+      auto b2 = buffer_info_vec[i2];
+      buffer_info_vec[i1] = b2;
+      buffer_info_vec[i2] = b1;
+
+      _pos_map[b1.as<BufferInfoNode>()] = i2;
+      _pos_map[b2.as<BufferInfoNode>()] = i1;
+    };
+
+    auto _pos = [&_pos_map](const auto* e) {
+      auto it = _pos_map.find(e);
+      if (it != _pos_map.end()) {
+        return it->second;
+      }
+      LOG(FATAL) << "node is not indexed in the _pos_map";
+      return -1;
+    };
+
+    for (; attempts < _max_attempts; ++attempts) {
+      rollback_pool_allocations = std::move(pool_allocations);
+      pool_allocations = std::move(greedy(buffer_info_vec));
+
+      // estimate the result buffers
+      std::unordered_map<PoolInfo, size_t, ObjectPtrHash, ObjectPtrEqual> pool_sizes =
+          find_highest(&pool_allocations);
+      // calculate the summary
+      size_t total = 0;
+      for (const auto& el : pool_sizes) {
+        total += el.second;
+      }
+      // accept/reject result heuristic
+      if (!total_size || /* first run */
+          (total_size > total || /* always accept if better, or accept with some probability */
+           rnd_func() % 100 < static_cast<int>(50 * (total - total_size) / total / attempts))) {
+        // remember the winning combination
+        result_pool_allocations = pool_allocations;
+        total_size = total;
+
+        // reached the desired size
+        if (total_size <= desired_bytes_) {
+          break;
+        }
+
+      } else {
+        // rollback
+        swap_buffers(swap_i2, swap_i1);
+        pool_allocations = std::move(rollback_pool_allocations);
+        pool_sizes = find_highest(&pool_allocations);
+      }
+
+      std::vector<const BufferInfoNode*> max_pool_buf;
+
+      for (const auto& it : pool_allocations) {
+        const auto* buf = it.first;
+        const auto pa = it.second;
+        size_t high_sz = pa->byte_offset + buf->size_bytes;
+        if (pool_sizes[pa->pool_info] == high_sz) {
+          max_pool_buf.push_back(buf);
+        }
+      }
+
+      // pick a buffer at the highest watermark
+      const BufferInfoNode* node = max_pool_buf[rnd_func() % max_pool_buf.size()];
+      std::vector<const BufferInfoNode*> first_level;
+      std::vector<const BufferInfoNode*> second_level;
+      collect_neighbor_lists(node, &first_level, &second_level, _pos);
+
+      // retry if no first level neighbors were collected
+      if (!first_level.size()) {
+        continue;
+      }
+
+      // pick the buffers to swap
+      const BufferInfoNode* swap_buf1 = first_level[rnd_func() % first_level.size()];
+      const BufferInfoNode* swap_buf2 = swap_buf1;
+      while (swap_buf2 == swap_buf1) {
+        swap_buf2 = second_level.size() && (!first_level.size() || (rnd_func() % 100 > 25))
+                        ? second_level[rnd_func() % second_level.size()]
+                        : first_level[rnd_func() % first_level.size()];
+
+        if (second_level.size() < 2 && first_level.size() < 2) break;
+      }
+      if (swap_buf1 == swap_buf2) {
+        continue;
+      }
+
+      swap_i1 = _pos(swap_buf1);
+      swap_i2 = _pos(swap_buf2);
+      // do the swap
+      swap_buffers(swap_i1, swap_i2);
+    }
+
+    Map<BufferInfo, PoolAllocation> result;
+    // return the winning combination
+    for (auto it : result_pool_allocations) {
+      result.Set(GetRef<BufferInfo>(it.first), it.second);
+    }
+    return result;
+  }
+};
+
+Map<BufferInfo, PoolAllocation> HillClimb(const Array<BufferInfo>& buffer_info_arr,
+                                          const Integer& memory_pressure) {
+  return HillClimbAllocator(memory_pressure).PlanMemory(buffer_info_arr);
+}
+
+TVM_REGISTER_GLOBAL("tir.usmp.algo.hill_climb")
+    .set_body_typed([](Array<BufferInfo> buffer_info_arr, Integer memory_pressure) {
+      return HillClimb(buffer_info_arr, memory_pressure);
+    });
+
+}  // namespace algo
+}  // namespace usmp
+}  // namespace tir
+}  // namespace tvm
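The accept/reject heuristic above is what makes this simulated annealing rather than pure hill climbing: a strictly better total is always kept, while a worse one is kept with a probability that shrinks with both the size of the regression and the attempt count. A small Python rendering of that acceptance rule (the function name is ours; the formula is the one from `PlanMemory`):

```python
def acceptance_percent(total: int, best: int, attempts: int) -> int:
    """Percent chance that a *worse* allocation (total > best) is still accepted,
    mirroring `50 * (total - total_size) / total / attempts` above (integer division)."""
    return 50 * (total - best) // total // attempts

# A 10% regression is kept ~4% of the time on the first attempt...
assert acceptance_percent(110_000, 100_000, 1) == 4
# ...and effectively never by attempt 5, annealing the search as it progresses.
assert acceptance_percent(110_000, 100_000, 5) == 0
```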
diff --git a/tests/python/unittest/test_tir_usmp_algo.py b/tests/python/unittest/test_tir_usmp_algo.py
index 1a763d083b10..1995695100cb 100644
--- a/tests/python/unittest/test_tir_usmp_algo.py
+++ b/tests/python/unittest/test_tir_usmp_algo.py
@@ -123,7 +123,7 @@ def test_no_pool_error():
     buffer_pool_allocations = fusmp_algo(buffer_info_arr, 0)
 
 
-@pytest.mark.parametrize("algorithm", ["greedy_by_size", "greedy_by_conflicts"])
+@pytest.mark.parametrize("algorithm", ["greedy_by_size", "greedy_by_conflicts", "hill_climb"])
 def test_name_based_ordering(algorithm):
     """ This checks when the size and conlicts are same a stable result is generated"""
@@ -142,9 +142,9 @@ def _test():
         bi_c = usmp_utils.BufferInfo(
             name_hint="bi_c", size_bytes=10, pool_candidates=[global_workspace_pool]
         )
-        bi_a.set_conflicts([bi_b])
-        bi_b.set_conflicts([bi_c])
-        bi_c.set_conflicts([bi_a])
+        bi_a.set_conflicts([bi_b, bi_c])
+        bi_b.set_conflicts([bi_c, bi_a])
+        bi_c.set_conflicts([bi_a, bi_b])
 
         buffer_info_arr = [bi_a, bi_b, bi_c]
         fusmp_algo = tvm.get_global_func(f"tir.usmp.algo.{algorithm}")
@@ -160,7 +160,7 @@ def _test():
 
 @pytest.mark.parametrize(
     ["algorithm", "workspace_size"],
-    [("greedy_by_size", 140), ("greedy_by_conflicts", 140)],
+    [("greedy_by_size", 140), ("greedy_by_conflicts", 140), ("hill_climb", 140)],
 )
 def test_linear(algorithm, workspace_size):
     """
@@ -222,7 +222,7 @@ def test_linear(algorithm, workspace_size):
 
 @pytest.mark.parametrize(
     ["algorithm", "workspace_size"],
-    [("greedy_by_size", 190), ("greedy_by_conflicts", 320)],
+    [("greedy_by_size", 190), ("greedy_by_conflicts", 320), ("hill_climb", 190)],
 )
 def test_fanout(algorithm, workspace_size):
     """
@@ -364,7 +364,11 @@ def run_model(input: T.handle, output: T.handle) -> None:
 
 @pytest.mark.parametrize(
     ["algorithm", "fast_memory_size", "slow_memory_size"],
-    [("greedy_by_size", 200704, 1418528), ("greedy_by_conflicts", 200704, 1418528)],
+    [
+        ("greedy_by_size", 200704, 1418528),
+        ("greedy_by_conflicts", 200704, 1418528),
+        ("hill_climb", 200704, 1117462),
+    ],
 )
 def test_mobilenet_subgraph(algorithm, fast_memory_size, slow_memory_size):
     target = Target("c")
@@ -529,7 +533,8 @@ def tvmgen_default_fused_nn_conv2d_add_fixed_point_multiply_clip_cast_cast(place
 
 @pytest.mark.parametrize(
-    ["algorithm", "workspace_size"], [("greedy_by_size", 7920256), ("greedy_by_conflicts", 7200256)]
+    ["algorithm", "workspace_size"],
+    [("greedy_by_size", 7920256), ("greedy_by_conflicts", 7200256), ("hill_climb", 7200256)],
 )
 def test_resnet_subgraph(algorithm, workspace_size):
     target = Target("c")
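The parametrizations above reach every planner, including the new one, through the FFI registry (`TVM_REGISTER_GLOBAL("tir.usmp.algo.hill_climb")` in the C++ file). A minimal standalone sketch of driving it from Python, with made-up buffer names and sizes:

```python
import tvm
from tvm.tir.usmp.utils import BufferInfo, PoolInfo

# One unbounded pool and two buffers that are live at the same time
# (the pool name, buffer names, and sizes here are illustrative only).
pool = PoolInfo("global_workspace", {})
bi_x = BufferInfo(name_hint="bi_x", size_bytes=128, pool_candidates=[pool])
bi_y = BufferInfo(name_hint="bi_y", size_bytes=256, pool_candidates=[pool])
bi_x.set_conflicts([bi_y])
bi_y.set_conflicts([bi_x])

# The second argument is the memory-pressure hint; 0 leaves the search unconstrained.
hill_climb = tvm.get_global_func("tir.usmp.algo.hill_climb")
for buffer_info, alloc in hill_climb([bi_x, bi_y], 0).items():
    print(buffer_info.name_hint, alloc.pool_info.pool_name, alloc.byte_offset)
```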
diff --git a/tests/python/unittest/test_tir_usmp_algo_hill_climb.py b/tests/python/unittest/test_tir_usmp_algo_hill_climb.py
new file mode 100644
index 000000000000..a5f1158a90c1
--- /dev/null
+++ b/tests/python/unittest/test_tir_usmp_algo_hill_climb.py
@@ -0,0 +1,397 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import sys
+import pytest
+import random
+import tvm
+from tvm.tir.usmp.utils import BufferInfo, PoolInfo
+
+
+def _check_max_workspace_size(buffer_pool_allocations, pool_info, size):
+    """Helper to check the maximum allocated memory size"""
+    max_workspace_size = 0
+    for buffer_info, pool_allocation in buffer_pool_allocations.items():
+        if pool_allocation.pool_info == pool_info:
+            size_candidate = pool_allocation.byte_offset + buffer_info.size_bytes
+            if size_candidate > max_workspace_size:
+                max_workspace_size = size_candidate
+    _diff = max_workspace_size.value - size
+    return (
+        (max_workspace_size.value == size),
+        "'{}': expected {} got {}, diff {:0.2f}% ({} bytes)".format(
+            pool_info.pool_name, size, max_workspace_size, 100 * _diff / size, _diff
+        ),
+    )
+
+
+def _verify_conflicts(buffer_info, pool_allocation, buffer_info_map):
+    """Helper to check the expected liveness conflicts"""
+    for conflict in buffer_info.conflicts:
+        conflict_pool_allocation = buffer_info_map[conflict]
+
+        if conflict_pool_allocation.pool_info == pool_allocation.pool_info:
+            assert conflict_pool_allocation.byte_offset != pool_allocation.byte_offset
+            l2 = (
+                max(
+                    conflict_pool_allocation.byte_offset + conflict.size_bytes,
+                    pool_allocation.byte_offset + buffer_info.size_bytes,
+                )
+                - min(conflict_pool_allocation.byte_offset, pool_allocation.byte_offset)
+            )
+            assert (
+                conflict.size_bytes + buffer_info.size_bytes <= l2
+            ), 'Conflicting: \n"{} @{}"\n"{} @{}"'.format(
+                conflict, conflict_pool_allocation, buffer_info, pool_allocation
+            )
+
+
+def _verify_all_conflicts(buffer_pool_allocations):
+    """Helper to verify the liveness conflicts"""
+    for buffer_info, pool_allocation in buffer_pool_allocations.items():
+        _verify_conflicts(buffer_info, pool_allocation, buffer_pool_allocations)
+
+
+def test_bounded(random_len=150, pools=[PoolInfo("default", {}, 65535), PoolInfo("slow", {})]):
+    """Tests two pools, one bounded and one unbounded"""
+    random.seed(0)
+    mem_range = [BufferInfo(str(i), random.randrange(1, 65535), pools) for i in range(random_len)]
+    for mr in mem_range:
+        pr = random.choice(mem_range)
+        while pr in (*mr.conflicts, mr):
+            pr = random.choice(mem_range)
+
+        mr.set_conflicts([*mr.conflicts, pr])
+        pr.set_conflicts([*pr.conflicts, mr])
+
+    fusmp_algo = tvm.get_global_func("tir.usmp.algo.hill_climb")
+    result_map = fusmp_algo(mem_range, 0)
+    _verify_all_conflicts(result_map)
+
+
+def __test_data_alloc_max():
+    """Test data"""
+    intervals = [
+        (0, 159, 2048),
+        (0, 13, 7904),
+        (4, 35, 16),
+        (12, 17, 32768),
+        (16, 21, 32768),
+    ]
+    return intervals
+
+
+def __test_data_deep_speech():
+    """Test data"""
+    intervals = [
+        (0, 159, 2048), (0, 151, 2048), (0, 13, 7904), (2, 49, 16), (4, 35, 16), (6, 21, 16),
+        (12, 17, 32768), (16, 21, 32768), (20, 27, 32768), (26, 31, 32768), (30, 35, 32768),
+        (34, 41, 32768), (40, 45, 32768), (44, 49, 32768), (48, 145, 32768), (54, 59, 2048),
+        (58, 483, 4096), (60, 65, 2048), (64, 461, 4096), (66, 71, 2048), (70, 439, 4096),
+        (72, 77, 2048), (76, 417, 4096), (78, 83, 2048), (82, 395, 4096), (84, 89, 2048),
+        (88, 373, 4096), (90, 95, 2048), (94, 351, 4096), (96, 101, 2048), (100, 329, 4096),
+        (102, 107, 2048), (106, 307, 4096), (108, 113, 2048), (112, 285, 4096), (114, 119, 2048),
+        (118, 263, 4096), (120, 125, 2048), (124, 241, 4096), (126, 131, 2048), (130, 219, 4096),
+        (132, 137, 2048), (136, 197, 4096), (138, 143, 2048), (142, 175, 4096), (144, 149, 2048),
+        (148, 153, 4096), (152, 163, 8192), (154, 171, 2048), (156, 181, 2048), (160, 167, 2048),
+        (162, 165, 2048), (168, 171, 2048), (170, 509, 2048), (174, 185, 8192), (176, 193, 2048),
+        (178, 203, 2048), (182, 189, 2048), (184, 187, 2048), (190, 193, 2048), (192, 511, 2048),
+        (196, 207, 8192), (198, 215, 2048), (200, 225, 2048), (204, 211, 2048), (206, 209, 2048),
+        (212, 215, 2048), (214, 513, 2048), (218, 229, 8192), (220, 237, 2048), (222, 247, 2048),
+        (226, 233, 2048), (228, 231, 2048), (234, 237, 2048), (236, 515, 2048), (240, 251, 8192),
+        (242, 259, 2048), (244, 269, 2048), (248, 255, 2048), (250, 253, 2048), (256, 259, 2048),
+        (258, 517, 2048), (262, 273, 8192), (264, 281, 2048), (266, 291, 2048), (270, 277, 2048),
+        (272, 275, 2048), (278, 281, 2048), (280, 519, 2048), (284, 295, 8192), (286, 303, 2048),
+        (288, 313, 2048), (292, 299, 2048), (294, 297, 2048), (300, 303, 2048), (302, 521, 2048),
+        (306, 317, 8192), (308, 325, 2048), (310, 335, 2048), (314, 321, 2048), (316, 319, 2048),
+        (322, 325, 2048), (324, 523, 2048), (328, 339, 8192), (330, 347, 2048), (332, 357, 2048),
+        (336, 343, 2048), (338, 341, 2048), (344, 347, 2048), (346, 525, 2048), (350, 361, 8192),
+        (352, 369, 2048), (354, 379, 2048), (358, 365, 2048), (360, 363, 2048), (366, 369, 2048),
+        (368, 527, 2048), (372, 383, 8192), (374, 391, 2048), (376, 401, 2048), (380, 387, 2048),
+        (382, 385, 2048), (388, 391, 2048), (390, 529, 2048), (394, 405, 8192), (396, 413, 2048),
+        (398, 423, 2048), (402, 409, 2048), (404, 407, 2048), (410, 413, 2048), (412, 531, 2048),
+        (416, 427, 8192), (418, 435, 2048), (420, 445, 2048), (424, 431, 2048), (426, 429, 2048),
+        (432, 435, 2048), (434, 533, 2048), (438, 449, 8192), (440, 457, 2048), (442, 467, 2048),
+        (446, 453, 2048), (448, 451, 2048), (454, 457, 2048), (456, 535, 2048), (460, 471, 8192),
+        (462, 479, 2048), (464, 489, 2048), (468, 475, 2048), (470, 473, 2048), (476, 479, 2048),
+        (478, 537, 2048), (482, 493, 8192), (484, 501, 2048), (486, 497, 2048), (490, 497, 2048),
+        (492, 495, 2048), (496, 626, 2048), (498, 501, 2048), (500, 626, 2048), (504, 549, 16),
+        (508, 543, 32768), (542, 549, 32768), (548, 555, 32768), (554, 563, 464), (560, 563, 256),
+        (562, 617, 2048), (564, 567, 1856), (566, 573, 1024), (568, 619, 1024), (570, 573, 1024),
+        (572, 577, 1024), (576, 579, 1024), (578, 605, 1024), (580, 593, 1024), (584, 587, 1024),
+        (586, 603, 1024), (594, 597, 1024), (596, 613, 1024), (604, 607, 1024), (606, 617, 1024),
+        (616, 621, 2048), (618, 621, 1024), (620, 626, 464),
+    ]
+    return intervals
+
+
+def __test_data_five():
+    """Test data"""
+    return [
+        (4, 5, 95),
+        (1, 4, 52135),
+        (3, 4, 12136),
+        (3, 5, 62099),
+        (4, 5, 50458),
+    ]
+
+
+def __test_data_simple():
+    """Test data"""
+    return [
+        (0, 23, 131072),  # 0
+        (4, 5, 65568),  # 1
+        (4, 9, 8192),  # 2
+        (8, 30, 15360),  # 3
+        (10, 11, 65568),  # 4
+        (10, 15, 4096),  # 5
+        (16, 17, 65552),  # 6
+        (16, 21, 2048),  # 7
+        (22, 23, 32784),  # 8
+        (22, 27, 1024),  # 9
+    ]
+
+
+def find_maximum_from_intervals(intervals):
+    """Expected maximum memory use for a list of (start, end, size) intervals"""
+    sorted_list = sorted(intervals, key=lambda _: _[0])
+    max_mem = 0
+    for t in range(sorted_list[0][0], sorted_list[-1][1] + 1):
+        max_mem = max(
+            max_mem, sum([size for (start, end, size) in sorted_list if t >= start and t <= end])
+        )
+    return max_mem
+
+
+@pytest.mark.parametrize(
+    "intervals",
+    [__test_data_alloc_max(), __test_data_simple(), __test_data_deep_speech(), __test_data_five()],
+)
+def test_intervals(intervals):
+    """Tests the supplied intervals"""
+    random.seed(0)
+    result = run_intervals(intervals)
+    assert result["tir.usmp.algo.hill_climb"] == True, f" {result}"
+
+
+def generate_range(sz, max_segment_sz=65535):
+    """Helper to generate sz ranges of random size up to max_segment_sz"""
+    for i in range(0, sz):
+        start = random.randrange(i, sz)
+        stop = random.randrange(start + 1, start + 2 + ((sz - start) // 2))
+        assert stop - start > 0
+        yield (start, stop, random.randrange(1, max_segment_sz))
+
+
+def test_random_intervals(interval_len=16):
+    """Tests a randomly generated interval set of length interval_len"""
+    random.seed(0)
+    intervals = list(generate_range(interval_len))
+    return run_intervals(intervals)
+
+
+def run_intervals(intervals):
+    """Helper to run the intervals through the algorithms"""
+    expected_mem = find_maximum_from_intervals(intervals)
+    pools = [PoolInfo("default", {})]
+    buffers = []
+    # populate
+    for i, (start, stop, size) in enumerate(intervals):
+        buf = BufferInfo(str(i), size, pools)
+        buffers.append(buf)
+
+    # intersect
+    for i, (i_start, i_stop, _) in enumerate(intervals):
+        conflicts = set()
+        for j, (j_start, j_stop, _) in enumerate(intervals):
+            start = min(i_start, j_start)
+            stop = max(i_stop, j_stop)
+            i_dur = i_stop - i_start + 1
+            j_dur = j_stop - j_start + 1
+
+            if i != j and (stop - start + 1 < i_dur + j_dur):
+                conflicts.add(buffers[j])
+
+        buffers[i].set_conflicts([c for c in sorted(conflicts, key=lambda c: c.name_hint)])
+
+    result = {}
+    for (alg, params) in [
+        ("tir.usmp.algo.hill_climb", (expected_mem,)),
+        ("tir.usmp.algo.greedy_by_size", (expected_mem,)),
+    ]:
+        fusmp_algo = tvm.get_global_func(alg)
+        print("\n", "started", alg)
+        buffer_info_arr = fusmp_algo(buffers, *params)
+        print()
+
+        _verify_all_conflicts(buffer_info_arr)
+        result[alg], msg = _check_max_workspace_size(buffer_info_arr, pools[0], expected_mem)
+        if not result[alg]:
+            print(alg, msg)
+
+    return result
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main([__file__] + sys.argv[1:]))
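For clarity, the `stop - start + 1 < i_dur + j_dur` test in `run_intervals` marks two buffers as conflicting exactly when their liveness intervals overlap: the union of the two intervals is then shorter than the sum of their durations. A small standalone check of that predicate (the helper name is ours):

```python
def overlaps(a, b):
    """True when liveness intervals a and b (inclusive (start, stop) pairs) share a timestep,
    using the same union-shorter-than-sum-of-durations test as run_intervals."""
    (a_start, a_stop), (b_start, b_stop) = a, b
    span = max(a_stop, b_stop) - min(a_start, b_start) + 1
    return span < (a_stop - a_start + 1) + (b_stop - b_start + 1)

assert overlaps((0, 5), (3, 8))      # share timesteps 3..5: union 9 < 6 + 6
assert not overlaps((0, 5), (6, 8))  # disjoint: union 9 == 6 + 3
```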