Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 26 additions & 21 deletions include/tvm/ir/memory_pools.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,29 +103,34 @@ struct PoolInfoNode : public Object {
TVM_DECLARE_FINAL_OBJECT_INFO(PoolInfoNode, Object);
};

/*!
* \brief The string parameter to indicate read and write access to a pool
* This needs to be kept in sync with PoolInfo.READ_WRITE_ACCESS in
* python/tvm/ir/memory_pools.py
*/
static constexpr const char* kTargetPoolReadWriteAccess = "rw";

/*!
* \brief The string parameter to indicate read only access to a pool
* This needs to be kept in sync with PoolInfo.READ_ONLY_ACCESS in
* python/tvm/ir/memory_pools.py
*/
static constexpr const char* kTargetPoolReadOnlyAccess = "ro";

/*! \brief The PoolSize is unrestricted for the memory planner */
static const int kUnrestrictedPoolSizeHint = -1;

/*! \brief The clock frequency is not known */
static const int kUnknownClockFrequency = -1;

/*! \brief The read bandwidth is not known */
static const int kUnknownReadBandwidth = -1;

/*! \brief The write bandwidth is not known */
static const int kUnknownWriteBandwidth = -1;

class PoolInfo : public ObjectRef {
public:
/*!
* \brief The string parameter to indicate read and write access to a pool
* This needs to be kept in sync with PoolInfo.READ_WRITE_ACCESS in
* python/tvm/ir/memory_pools.py
*/
static constexpr const char* kTargetPoolReadWriteAccess = "rw";
/*!
* \brief The string parameter to indicate read only access to a pool
* This needs to be kept in sync with PoolInfo.READ_ONLY_ACCESS in
* python/tvm/ir/memory_pools.py
*/
static constexpr const char* kTargetPoolReadOnlyAccess = "ro";
/*! \brief The PoolSize is unrestricted for the memory planner */
static const int kUnrestrictedPoolSizeHint = -1;
/*! \brief The clock frequency is not known */
static const int kUnknownClockFrequency = -1;
/*! \brief The read bandwidth is not known */
static const int kUnknownReadBandwidth = -1;
/*! \brief The write bandwidth is not known */
static const int kUnknownWriteBandwidth = -1;

TVM_DLL PoolInfo(String pool_name, Map<Target, String> target_access,
Integer size_hint_bytes = kUnrestrictedPoolSizeHint,
Integer clock_frequency_hz = kUnknownClockFrequency,
Expand Down
2 changes: 1 addition & 1 deletion src/tir/usmp/algo/greedy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ size_t GreedyBase::round_up_to_byte_alignment(const size_t& non_aligned_byte_off
*/
bool GreedyBase::IsValidPlacement(const PoolInfo& candidate_pool, const size_t& next_offset,
const size_t& size_bytes) {
if (candidate_pool->size_hint_bytes == PoolInfo::kUnrestrictedPoolSizeHint) {
if (candidate_pool->size_hint_bytes == kUnrestrictedPoolSizeHint) {
// this means pool is not bounded
return true;
}
Expand Down
9 changes: 4 additions & 5 deletions src/tir/usmp/transform/assign_pool_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,10 @@ class PoolInfoAssigner : public StmtExprMutator {
ICHECK(target_host) << "main function does not have a target attr";
WorkspaceMemoryPools workspace_pools =
module->GetAttr<WorkspaceMemoryPools>(tvm::attr::kWorkspaceMemoryPools)
.value_or(WorkspaceMemoryPools({PoolInfo(
"global_workspace", {{target_host.value(), PoolInfo::kTargetPoolReadWriteAccess}},
PoolInfo::kUnrestrictedPoolSizeHint, PoolInfo::kUnknownClockFrequency,
PoolInfo::kUnknownReadBandwidth, PoolInfo::kUnknownWriteBandwidth, 0, 0,
{{target_host.value(), 1}}, Bool(true))}));
.value_or(WorkspaceMemoryPools(
{PoolInfo("global_workspace", {{target_host.value(), kTargetPoolReadWriteAccess}},
kUnrestrictedPoolSizeHint, kUnknownClockFrequency, kUnknownReadBandwidth,
kUnknownWriteBandwidth, 0, 0, {{target_host.value(), 1}}, Bool(true))}));
Array<PoolInfo> pool_infos = workspace_pools->pools;
for (const PoolInfo& pool_info : pool_infos) {
for (const auto& kv : pool_info->target_access) {
Expand Down