diff --git a/include/tvm/ir/memory_pools.h b/include/tvm/ir/memory_pools.h new file mode 100644 index 000000000000..c6e52648ebd4 --- /dev/null +++ b/include/tvm/ir/memory_pools.h @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ir/memory_pools.h + * \brief The object definition for relay.build argument type of memory pools + */ +#ifndef TVM_IR_MEMORY_POOLS_H_ +#define TVM_IR_MEMORY_POOLS_H_ + +#include +#include + +namespace tvm { + +/*! + * \brief Describes a pool of memory accessible by one or more targets. + */ +struct PoolInfoNode : public Object { + /*! \brief The name of the memory pool */ + String pool_name; + /*! \brief The expected size hint to be used by the allocator. + * The size_hint_bytes is set to kUnrestrictedPoolSizeHint + * to indicate the pool is not size restricted. + */ + Integer size_hint_bytes; + /*! \brief The accessibility from each Target */ + Map target_access; // 'rw' or 'ro' + /*! \brief The clock frequency of the memory in Hz */ + Integer clock_frequency_hz; + /*! \brief The read bandwidth in bytes/cycle */ + Integer read_bandwidth_bytes_per_cycle; + /*! 
\brief The write bandwidth in bytes/cycle */ + Integer write_bandwidth_bytes_per_cycle; + /*! \brief The read latency in cycles */ + Integer read_latency_cycles; + /*! \brief The write latency in cycles */ + Integer write_latency_cycles; + /*! \brief The burst length in bytes for each Target */ + Map target_burst_bytes; + /*! \brief Whether pool is internally generated. + * The internal pools will be generated as part of + * the entry point code generation of the executor + */ + bool is_internal = false; + + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("pool_name", &pool_name); + v->Visit("size_hint_bytes", &size_hint_bytes); + v->Visit("target_access", &target_access); + v->Visit("clock_frequency_hz", &clock_frequency_hz); + v->Visit("read_bandwidth_bytes_per_cycle", &read_bandwidth_bytes_per_cycle); + v->Visit("write_bandwidth_bytes_per_cycle", &write_bandwidth_bytes_per_cycle); + v->Visit("read_latency_cycles", &read_latency_cycles); + v->Visit("write_latency_cycles", &write_latency_cycles); + v->Visit("target_burst_bytes", &target_burst_bytes); + v->Visit("is_internal", &is_internal); + } + + bool SEqualReduce(const PoolInfoNode* other, SEqualReducer equal) const { + return equal(pool_name, other->pool_name) && equal(size_hint_bytes, other->size_hint_bytes) && + equal(target_access, other->target_access) && + equal(target_access, other->target_access) && + equal(clock_frequency_hz, other->clock_frequency_hz) && + equal(read_bandwidth_bytes_per_cycle, other->read_bandwidth_bytes_per_cycle) && + equal(write_bandwidth_bytes_per_cycle, other->write_bandwidth_bytes_per_cycle) && + equal(read_latency_cycles, other->read_latency_cycles) && + equal(write_latency_cycles, other->write_latency_cycles) && + equal(target_burst_bytes, other->target_burst_bytes) && + equal(is_internal, other->is_internal); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(pool_name); + hash_reduce(size_hint_bytes); + hash_reduce(target_access); + 
hash_reduce(clock_frequency_hz); + hash_reduce(read_bandwidth_bytes_per_cycle); + hash_reduce(write_bandwidth_bytes_per_cycle); + hash_reduce(read_latency_cycles); + hash_reduce(write_latency_cycles); + hash_reduce(target_burst_bytes); + hash_reduce(is_internal); + } + + static constexpr const char* _type_key = "ir.PoolInfo"; + TVM_DECLARE_FINAL_OBJECT_INFO(PoolInfoNode, Object); +}; + +class PoolInfo : public ObjectRef { + public: + /*! + * \brief The string parameter to indicate read and write access to a pool + * This needs to be kept in sync with PoolInfo.READ_WRITE_ACCESS in + * python/tvm/ir/memory_pools.py + */ + static constexpr const char* kTargetPoolReadWriteAccess = "rw"; + /*! + * \brief The string parameter to indicate read only access to a pool + * This needs to be kept in sync with PoolInfo.READ_ONLY_ACCESS in + * python/tvm/ir/memory_pools.py + */ + static constexpr const char* kTargetPoolReadOnlyAccess = "ro"; + /*! \brief The PoolSize is unrestricted for the memory planner */ + static const int kUnrestrictedPoolSizeHint = -1; + /*! \brief The clock frequency is not known */ + static const int kUnknownClockFrequency = -1; + /*! \brief The read bandwidth is not known */ + static const int kUnknownReadBandwidth = -1; + /*! 
\brief The write bandwidth is not known */ + static const int kUnknownWriteBandwidth = -1; + + TVM_DLL PoolInfo(String pool_name, Map target_access, + Integer size_hint_bytes = kUnrestrictedPoolSizeHint, + Integer clock_frequency_hz = kUnknownClockFrequency, + Integer read_bandwidth_bytes_per_cycle = kUnknownReadBandwidth, + Integer write_bandwidth_bytes_per_cycle = kUnknownWriteBandwidth, + Integer read_latency_cycles = 0, Integer write_latency_cycles = 0, + Map target_burst_bytes = {}, Bool is_internal = Bool(false)); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PoolInfo, ObjectRef, PoolInfoNode); +}; + +struct WorkspaceMemoryPoolsNode : public Object { + Array pools; + + void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pools", &pools); } + + bool SEqualReduce(const WorkspaceMemoryPoolsNode* other, SEqualReducer equal) const { + return equal(pools, other->pools); + } + + void SHashReduce(SHashReducer hash_reduce) const { hash_reduce(pools); } + + static constexpr const char* _type_key = "ir.WorkspaceMemoryPools"; + TVM_DECLARE_FINAL_OBJECT_INFO(WorkspaceMemoryPoolsNode, Object); +}; + +class WorkspaceMemoryPools : public ObjectRef { + public: + TVM_DLL WorkspaceMemoryPools(Array pools); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(WorkspaceMemoryPools, ObjectRef, WorkspaceMemoryPoolsNode); +}; + +} // namespace tvm + +#endif // TVM_IR_MEMORY_POOLS_H_ diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h index ec8c9b6c4b2c..d308db87af8b 100644 --- a/include/tvm/ir/module.h +++ b/include/tvm/ir/module.h @@ -495,6 +495,15 @@ constexpr const char* kExecutor = "executor"; */ constexpr const char* kRuntime = "runtime"; +/*! 
+ * \brief workspace memory pools of the module + * + * Type: WorkspaceMemoryPools + * + * \sa tvm::WorkspaceMemoryPools + */ +constexpr const char* kWorkspaceMemoryPools = "workspace_memory_pools"; + } // namespace attr } // namespace tvm #endif // TVM_IR_MODULE_H_ diff --git a/include/tvm/tir/usmp/utils.h b/include/tvm/tir/usmp/utils.h index 9ebe7f29b1f4..aa4f27de6d8b 100644 --- a/include/tvm/tir/usmp/utils.h +++ b/include/tvm/tir/usmp/utils.h @@ -26,6 +26,7 @@ #define TVM_TIR_USMP_UTILS_H_ #include +#include #include #include #include @@ -44,111 +45,6 @@ constexpr const char* kUSMPAlgorithmOption = "tir.usmp.algorithm"; namespace tir { namespace usmp { -/*! - * \brief Describes a pool of memory accessible by one or more targets. - */ -struct PoolInfoNode : public Object { - /*! \brief The name of the memory pool */ - String pool_name; - /*! \brief The expected size hint to be used by the allocator. - * The size_hint_bytes is set to kUnrestrictedPoolSizeHint - * to indicate the pool is not size restricted. - */ - Integer size_hint_bytes; - /*! \brief The accessibility from each Target */ - Map target_access; // 'rw' or 'ro' - /*! \brief The clock frequency of the memory in Hz */ - Integer clock_frequency_hz; - /*! \brief The read bandwidth in bytes/cycle */ - Integer read_bandwidth_bytes_per_cycle; - /*! \brief The write bandwidth in bytes/cycle */ - Integer write_bandwidth_bytes_per_cycle; - /*! \brief The read latency in cycles */ - Integer read_latency_cycles; - /*! \brief The write latency in cycles */ - Integer write_latency_cycles; - /*! \brief The burst length in bytes for each Target */ - Map target_burst_bytes; - /*! \brief Whether pool is internally generated. 
- * The internal pools will be generated as part of - * the entry point code generation of the executor - */ - bool is_internal = false; - - void VisitAttrs(tvm::AttrVisitor* v) { - v->Visit("pool_name", &pool_name); - v->Visit("size_hint_bytes", &size_hint_bytes); - v->Visit("target_access", &target_access); - v->Visit("clock_frequency_hz", &clock_frequency_hz); - v->Visit("read_bandwidth_bytes_per_cycle", &read_bandwidth_bytes_per_cycle); - v->Visit("write_bandwidth_bytes_per_cycle", &write_bandwidth_bytes_per_cycle); - v->Visit("read_latency_cycles", &read_latency_cycles); - v->Visit("write_latency_cycles", &write_latency_cycles); - v->Visit("target_burst_bytes", &target_burst_bytes); - v->Visit("is_internal", &is_internal); - } - - bool SEqualReduce(const PoolInfoNode* other, SEqualReducer equal) const { - return equal(pool_name, other->pool_name) && equal(size_hint_bytes, other->size_hint_bytes) && - equal(target_access, other->target_access) && - equal(target_access, other->target_access) && - equal(clock_frequency_hz, other->clock_frequency_hz) && - equal(read_bandwidth_bytes_per_cycle, other->read_bandwidth_bytes_per_cycle) && - equal(write_bandwidth_bytes_per_cycle, other->write_bandwidth_bytes_per_cycle) && - equal(read_latency_cycles, other->read_latency_cycles) && - equal(write_latency_cycles, other->write_latency_cycles) && - equal(target_burst_bytes, other->target_burst_bytes) && - equal(is_internal, other->is_internal); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(pool_name); - hash_reduce(size_hint_bytes); - hash_reduce(target_access); - hash_reduce(clock_frequency_hz); - hash_reduce(read_bandwidth_bytes_per_cycle); - hash_reduce(write_bandwidth_bytes_per_cycle); - hash_reduce(read_latency_cycles); - hash_reduce(write_latency_cycles); - hash_reduce(target_burst_bytes); - hash_reduce(is_internal); - } - - static constexpr const char* _type_key = "tir.usmp.PoolInfo"; - TVM_DECLARE_FINAL_OBJECT_INFO(PoolInfoNode, Object); 
-}; - -class PoolInfo : public ObjectRef { - public: - /*! - * \brief The string parameter to indicate read and write access to a pool - * This needs to be kept in sync with PoolInfo.READ_WRITE_ACCESS in - * python/tvm/tir/usmp/utils.py - */ - static constexpr const char* kTargetPoolReadWriteAccess = "rw"; - /*! - * \brief The string parameter to indicate read only access to a pool - * This needs to be kept in sync with PoolInfo.READ_ONLY_ACCESS in - * python/tvm/tir/usmp/utils.py - */ - static constexpr const char* kTargetPoolReadOnlyAccess = "ro"; - /*! \brief The PoolSize is unrestricted for the memory planner */ - static const int kUnrestrictedPoolSizeHint = -1; - /*! \brief The clock frequency is not known */ - static const int kUnknownClockFrequency = -1; - /*! \brief The read bandwidth is not known */ - static const int kUnknownReadBandwidth = -1; - /*! \brief The write bandwidth is not known */ - static const int kUnknownWriteBandwidth = -1; - - TVM_DLL PoolInfo(String pool_name, Map target_access, Integer size_hint_bytes, - Integer clock_frequency_hz, Integer read_bandwidth_bytes_per_cycle, - Integer write_bandwidth_bytes_per_cycle, Integer read_latency_cycles, - Integer write_latency_cycles, Map target_burst_bytes, - Bool is_internal); - TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(PoolInfo, ObjectRef, PoolInfoNode); -}; - /*! * \brief Describes an abstract memory buffer that will get allocated inside a pool. * The actual memory buffer in represented by PoolAllocationNode after static memory planning. diff --git a/python/tvm/__init__.py b/python/tvm/__init__.py index 3f2f277d2926..ac3acdde3088 100644 --- a/python/tvm/__init__.py +++ b/python/tvm/__init__.py @@ -43,6 +43,8 @@ from .ir import transform from .ir import instrument from .ir import container +from .ir import PoolInfo +from .ir import WorkspaceMemoryPools from . 
import ir # tvm.tir diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py index 83557a3eae19..928631ce10de 100644 --- a/python/tvm/ir/__init__.py +++ b/python/tvm/ir/__init__.py @@ -30,6 +30,7 @@ from .module import IRModule from .attrs import Attrs, DictAttrs, make_node from .container import Array, Map +from .memory_pools import PoolInfo, WorkspaceMemoryPools from . import transform from . import instrument diff --git a/python/tvm/ir/memory_pools.py b/python/tvm/ir/memory_pools.py new file mode 100644 index 000000000000..6fa6bb41280e --- /dev/null +++ b/python/tvm/ir/memory_pools.py @@ -0,0 +1,132 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Objects for Memory Pools to be used within the compilation""" + +from typing import Optional, List + +from tvm._ffi import register_object +from tvm.runtime import Object +from . import _ffi_api + + +@register_object("ir.PoolInfo") +class PoolInfo(Object): + """PoolInfo object holds information related to memory pools + where the statically sized allocate nodes will pooled into. 
+
+    Parameters
+    ----------
+    pool_name : str
+        The name of the memory pool
+
+    target_access : Dict[Target, str]
+        A dictionary where keys describe which targets could
+        access the pool where value could take the values :
+        a) "rw" : read-write access
+        b) "ro" : read-only access
+
+    size_hint_bytes : Optional[int]
+        The expected size hint to be used by the allocator.
+        The default value would be -1 which means the pool
+        is not size restricted.
+
+    clock_frequency_hz : Optional[int]
+        The clock frequency that the memory pool runs at in Hz.
+        If not specified/known, this will default to -1 indicating
+        it hasn't been defined.
+
+    read_bandwidth_bytes_per_cycle : Optional[int]
+        The read bandwidth of the memory pool in bytes/cycle.
+        If not specified/known, this will default to -1 indicating
+        it hasn't been defined.
+
+    write_bandwidth_bytes_per_cycle : Optional[int]
+        The write bandwidth of the memory pool in bytes/cycle.
+        If not specified/known, this will default to -1 indicating
+        it hasn't been defined.
+
+    read_latency_cycles : Optional[int]
+        The read latency of the memory pool in cycles.
+        If not specified/known, this will default to 0.
+
+    write_latency_cycles : Optional[int]
+        The write latency of the memory pool in cycles.
+        If not specified/known, this will default to 0.
+
+    target_burst_bytes : Optional[Union[Dict[Target, int], None]]
+        The burst length of the memory pool in bytes per target.
+        If not specified/known for a given target, a burst length
+        of 1 byte will be assumed.
+ + """ + + # The string parameter to indicate read and write access to a pool + # This needs to be kept in sync with kTargetPoolReadWriteAccess in + # include/tvm/ir/memory_pools.h + READ_WRITE_ACCESS = "rw" + # The string parameter to indicate read only access to a pool + # This needs to be kept in sync with kTargetPoolReadOnlyAccess in + # include/tvm/ir/memory_pools.h + READ_ONLY_ACCESS = "ro" + + def __init__( + self, + pool_name: str, + target_access, # Dict[Target, str] + size_hint_bytes: Optional[int] = -1, + clock_frequency_hz: Optional[int] = -1, + read_bandwidth_bytes_per_cycle: Optional[int] = -1, + write_bandwidth_bytes_per_cycle: Optional[int] = -1, + read_latency_cycles: Optional[int] = 0, + write_latency_cycles: Optional[int] = 0, + target_burst_bytes=None, # Optional[Union[Dict[target.Target, int], None]] + ): + if not target_burst_bytes: + target_burst_bytes = dict() + + self.__init_handle_by_constructor__( + _ffi_api.PoolInfo, # type: ignore # pylint: disable=no-member + pool_name, + target_access, + size_hint_bytes, + clock_frequency_hz, + read_bandwidth_bytes_per_cycle, + write_bandwidth_bytes_per_cycle, + read_latency_cycles, + write_latency_cycles, + target_burst_bytes, + ) + + +@register_object("ir.WorkspaceMemoryPools") +class WorkspaceMemoryPools(Object): + """This object contains a list of PoolInfo objects to be used as + workspace memory in the compilation + + Parameters + ---------- + pools : List[PoolInfo] + The list of PoolInfo objects to be used with the compilation + """ + + def __init__( + self, + pools: List[PoolInfo], + ): + self.__init_handle_by_constructor__( + _ffi_api.WorkspaceMemoryPools, pools # type: ignore # pylint: disable=no-member + ) diff --git a/python/tvm/micro/model_library_format.py b/python/tvm/micro/model_library_format.py index 79866fe08399..3f64154b4edf 100644 --- a/python/tvm/micro/model_library_format.py +++ b/python/tvm/micro/model_library_format.py @@ -47,14 +47,16 @@ class 
UnsupportedInModelLibraryFormatError(Exception): def generate_c_interface_header( - module_name, inputs, outputs, devices, workspace_size, include_path + module_name, inputs, outputs, pools, devices, workspace_size, include_path ): """Generate C Interface header to be included in MLF""" mangled_name = to_c_variable_style(prefix_generated_name(module_name)) metadata_header = os.path.join(include_path, f"{mangled_name}.h") interface_c_create = tvm._ffi.get_global_func("runtime.InterfaceCCreate") - interface_c_module = interface_c_create(module_name, inputs, outputs, devices, workspace_size) + interface_c_module = interface_c_create( + module_name, inputs, outputs, pools, devices, workspace_size + ) with open(metadata_header, "w") as header_file: header_file.write(interface_c_module.get_source()) @@ -275,6 +277,10 @@ def _get_inputs_and_outputs_from_module(mod): return inputs, outputs +def _get_pools_from_module(mod): + return list(dict(mod.executor_codegen_metadata.pool_inputs).values()) + + def _should_generate_interface_header(mod): return "interface-api" in mod.executor and mod.executor["interface-api"] == "c" @@ -344,9 +350,10 @@ def _export_graph_model_library_format( include_path.mkdir() inputs, outputs = _get_inputs_and_outputs_from_module(mod) devices = mod.get_devices() + pools = _get_pools_from_module(mod) workspace_size = int(metadata["memory"]["functions"]["main"][0]["workspace_size_bytes"]) generate_c_interface_header( - mod.libmod_name, inputs, outputs, devices, workspace_size, include_path + mod.libmod_name, inputs, outputs, pools, devices, workspace_size, include_path ) parameters_dir = tempdir / "parameters" diff --git a/python/tvm/relay/backend/executor_factory.py b/python/tvm/relay/backend/executor_factory.py index 68a76588440f..9ff7a7a8120b 100644 --- a/python/tvm/relay/backend/executor_factory.py +++ b/python/tvm/relay/backend/executor_factory.py @@ -106,6 +106,7 @@ def __init__( libmod_name, params, function_metadata, + 
executor_codegen_metadata, devices, ): self.ir_mod = ir_mod @@ -118,6 +119,7 @@ def __init__( self.params = params self.iter_cnt = 0 self.function_metadata = function_metadata + self.executor_codegen_metadata = executor_codegen_metadata self.devices = devices def get_devices(self): diff --git a/python/tvm/relay/build_module.py b/python/tvm/relay/build_module.py index d242a1ca07f3..5cfd3a16c3bc 100644 --- a/python/tvm/relay/build_module.py +++ b/python/tvm/relay/build_module.py @@ -101,6 +101,7 @@ def __init__(self): self._set_params_func = self.mod["set_params"] self._get_params_func = self.mod["get_params"] self._get_function_metadata = self.mod["get_function_metadata"] + self._get_executor_codegen_metadata = self.mod["get_executor_codegen_metadata"] self._get_devices = self.mod["get_devices"] self._get_irmodule = self.mod["get_irmodule"] @@ -111,6 +112,7 @@ def build( target_host=None, executor=Executor("graph"), runtime=Runtime("cpp"), + workspace_memory_pools=None, params=None, mod_name=None, ): @@ -142,6 +144,11 @@ def build( Runtime configuration to use when building the model. Defaults to "cpp" if no runtime specified. + workspace_memory_pools : Optional[WorkspaceMemoryPools] + The object that contains an Array of PoolInfo objects + that hold properties of workspace pools that could be + used by the inference. + params : dict of str to NDArray Input parameters to the graph that do not change during inference time. Used for constant folding. 
@@ -186,7 +193,7 @@ def build( mod_name = mangle_module_name(mod_name) - self._build(mod, target, target_host, executor, runtime, mod_name) + self._build(mod, target, target_host, executor, runtime, workspace_memory_pools, mod_name) autotvm.GLOBAL_SCOPE.silent = old_autotvm_silent # Get artifacts @@ -248,6 +255,12 @@ def get_function_metadata(self): each PrimFunc""" return self._get_function_metadata() + def get_executor_codegen_metadata(self): + """Return the metadata produced after executor + codegen + """ + return self._get_executor_codegen_metadata() + def get_devices(self): """Returns a list of devices configured in this module""" return self._get_devices() @@ -349,6 +362,7 @@ def build( target_host=None, executor=Executor("graph"), runtime=Runtime("cpp"), + workspace_memory_pools=None, params=None, mod_name="default", ): @@ -382,6 +396,11 @@ def build( Runtime configuration to use when building the model. Defaults to "cpp" if no runtime specified. + workspace_memory_pools : Optional[WorkspaceMemoryPools] + The object that contains an Array of PoolInfo objects + that hold properties of workspace pools that could be + used by the inference. + params : dict of str to NDArray Input parameters to the graph that do not change during inference time. Used for constant folding. 
@@ -452,11 +471,13 @@ def build( params=params, executor=executor, runtime=runtime, + workspace_memory_pools=workspace_memory_pools, mod_name=mod_name, ) func_metadata = bld_mod.get_function_metadata() devices = bld_mod.get_devices() lowered_ir_mods = bld_mod.get_irmodule() + executor_codegen_metadata = bld_mod.get_executor_codegen_metadata() if str(executor) == "aot": executor_factory = _executor_factory.AOTExecutorFactoryModule( @@ -469,6 +490,7 @@ def build( mod_name, params, func_metadata, + executor_codegen_metadata, devices, ) elif str(executor) == "graph": diff --git a/python/tvm/tir/usmp/utils.py b/python/tvm/tir/usmp/utils.py index d138238ad888..a7221cfe6f8e 100644 --- a/python/tvm/tir/usmp/utils.py +++ b/python/tvm/tir/usmp/utils.py @@ -17,12 +17,12 @@ """USMP Utilities and Data Structures""" # pylint: disable=invalid-name -from typing import Dict, Optional, List, Union +from typing import Optional, List from tvm._ffi import register_object from tvm.runtime import Object -from tvm.target import Target from . import _ffi_api +from ...ir.memory_pools import PoolInfo # The allocate node attribute to indicate candidate memory pools. @@ -31,95 +31,6 @@ CANDIDATE_MEMORY_POOL_ATTR = "candidate_memory_pools" -@register_object("tir.usmp.PoolInfo") -class PoolInfo(Object): - """PoolInfo object holds information related to memory pools - where the statically sized allocate nodes will pooled into. - - Parameters - ---------- - pool_name : str - The name of the memory pool - - target_access : Dict[Target, str] - A dictionary where keys describe which targets could - access the pool where value could take the values : - a) "rw" : read-write access - b) "ro" : write-only acesss - - size_hint_bytes : Optional[int] - The expected size hint to be used by the allocator. - The default value would be -1 which means the pool - is not size restricted. - - clock_frequency_hz : Optional[int] - The clock frequency that the memory pool runs at in Hz. 
- If not specified/known, this will default to -1 indicating - it hasn't been defined. - - read_bandwidth_bytes_per_cycle : Optional[int] - The read bandwidth of the memory pool in bytes/cycle. - If not specified/known, this will default to -1 indicating - it hasn't been defined. - - write_bandwidth_bytes_per_cycle : Optional[int] - The write bandwidth of the memory pool in bytes/cycle. - If not specified/known, this will default to -1 indicating - it hasn't been defined. - - read_latency_cycles : Optional[int] - The read latency of the memory pool in cycles. - If not specified/known, this will default to 0. - - write_latency_cycles : Optional[int] - The write latency of the memory pool in cycles. - If not specified/known, this will default to 0. - - target_burst_bytes : Optional[Union[Dict[Target, int], None]] - The burst length of the memory pool in bytes per target. - If not specified/known for a given target, a burst length - of 1 byte will be assumed. - - """ - - # The string parameter to indicate read and write access to a pool - # This needs to be kept in sync with kTargetPoolReadWriteAccess in - # include/tvm/tir/usmp/utils.h - READ_WRITE_ACCESS = "rw" - # The string parameter to indicate read only access to a pool - # This needs to be kept in sync with kTargetPoolReadOnlyAccess in - # include/tvm/tir/usmp/utils.h - READ_ONLY_ACCESS = "ro" - - def __init__( - self, - pool_name: str, - target_access: Dict[Target, str], - size_hint_bytes: Optional[int] = -1, - clock_frequency_hz: Optional[int] = -1, - read_bandwidth_bytes_per_cycle: Optional[int] = -1, - write_bandwidth_bytes_per_cycle: Optional[int] = -1, - read_latency_cycles: Optional[int] = 0, - write_latency_cycles: Optional[int] = 0, - target_burst_bytes: Optional[Union[Dict[Target, int], None]] = None, - ): - if not target_burst_bytes: - target_burst_bytes = dict() - - self.__init_handle_by_constructor__( - _ffi_api.PoolInfo, # type: ignore # pylint: disable=no-member - pool_name, - target_access, - 
size_hint_bytes, - clock_frequency_hz, - read_bandwidth_bytes_per_cycle, - write_bandwidth_bytes_per_cycle, - read_latency_cycles, - write_latency_cycles, - target_burst_bytes, - ) - - @register_object("tir.usmp.BufferInfo") class BufferInfo(Object): """BufferInfo object holds information related to buffers diff --git a/src/ir/memory_pools.cc b/src/ir/memory_pools.cc new file mode 100644 index 000000000000..5cf0035c90b2 --- /dev/null +++ b/src/ir/memory_pools.cc @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file src/ir/memory_pools.cc + * \brief The object definition for relay.build argument type of memory pools + */ + +#include +#include + +namespace tvm { + +PoolInfo::PoolInfo(String pool_name, Map target_access, Integer size_hint_bytes, + Integer clock_frequency_hz, Integer read_bandwidth_bytes_per_cycle, + Integer write_bandwidth_bytes_per_cycle, Integer read_latency_cycles, + Integer write_latency_cycles, Map target_burst_bytes, + Bool is_internal) { + auto poolinfo_node = make_object(); + poolinfo_node->pool_name = pool_name; + poolinfo_node->size_hint_bytes = size_hint_bytes; + poolinfo_node->target_access = target_access; + poolinfo_node->clock_frequency_hz = clock_frequency_hz; + poolinfo_node->read_bandwidth_bytes_per_cycle = read_bandwidth_bytes_per_cycle; + poolinfo_node->write_bandwidth_bytes_per_cycle = write_bandwidth_bytes_per_cycle; + poolinfo_node->read_latency_cycles = read_latency_cycles; + poolinfo_node->write_latency_cycles = write_latency_cycles; + poolinfo_node->target_burst_bytes = target_burst_bytes; + poolinfo_node->is_internal = is_internal; + data_ = std::move(poolinfo_node); +} + +TVM_REGISTER_NODE_TYPE(PoolInfoNode); +TVM_REGISTER_GLOBAL("ir.PoolInfo") + .set_body_typed([](String pool_name, Map target_access, Integer size_hint_bytes, + Integer clock_frequency_hz, Integer read_bandwidth_bytes_per_cycle, + Integer write_bandwidth_bytes_per_cycle, Integer read_latency_cycles, + Integer write_latency_cycles, Map target_burst_bytes) { + return PoolInfo(pool_name, target_access, size_hint_bytes, clock_frequency_hz, + read_bandwidth_bytes_per_cycle, write_bandwidth_bytes_per_cycle, + read_latency_cycles, write_latency_cycles, target_burst_bytes, Bool(false)); + }); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "PoolInfoNode(\n" + << " pool_name=" << node->pool_name << ",\n target_access=" << node->target_access + << ",\n 
size_hint_bytes=" << node->size_hint_bytes + << ",\n clock_frequency_hz=" << node->clock_frequency_hz + << ",\n read_bandwidth_bytes_per_cycle=" << node->read_bandwidth_bytes_per_cycle + << ",\n write_bandwidth_bytes_per_cycle=" << node->write_bandwidth_bytes_per_cycle + << ",\n read_latency_cycles=" << node->read_latency_cycles + << ",\n write_latency_cycles=" << node->write_latency_cycles + << ",\n target_burst_bytes=" << node->target_burst_bytes << ")"; + }); + +WorkspaceMemoryPools::WorkspaceMemoryPools(Array pools) { + auto workspace_memory_pools_node = make_object(); + workspace_memory_pools_node->pools = pools; + data_ = std::move(workspace_memory_pools_node); +} + +TVM_REGISTER_NODE_TYPE(WorkspaceMemoryPoolsNode); +TVM_REGISTER_GLOBAL("ir.WorkspaceMemoryPools").set_body_typed([](Array pools) { + return WorkspaceMemoryPools(pools); +}); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { + auto* node = static_cast(ref.get()); + p->stream << "WorkspaceMemoryPoolsNode(\n" + << "pools=" << node->pools << ")"; + }); + +} // namespace tvm diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index f076efeb4ac5..a14cef98669a 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -1021,7 +1021,7 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->codegen_->ListDevices(); }); - } else if (name == "get_metadata") { + } else if (name == "get_executor_codegen_metadata") { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = output_.metadata; }); } else { diff --git a/src/relay/backend/build_module.cc b/src/relay/backend/build_module.cc index 2f986669e758..aa9c084de4f7 100644 --- a/src/relay/backend/build_module.cc +++ b/src/relay/backend/build_module.cc @@ -23,6 +23,7 @@ */ #include #include +#include 
#include #include #include @@ -103,8 +104,8 @@ struct ExecutorCodegen { Array ListDevices() { return CallFunc>("get_devices"); } - relay::backend::ExecutorCodegenMetadata GetMetadata() { - return CallFunc("get_metadata"); + relay::backend::ExecutorCodegenMetadata GetExecutorCodegenMetadata() { + return CallFunc("get_executor_codegen_metadata"); } virtual ~ExecutorCodegen() {} @@ -188,8 +189,8 @@ class RelayBuildModule : public runtime::ModuleNode { [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->GetModule(); }); } else if (name == "build") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { - ICHECK_EQ(args.num_args, 6); - this->Build(args[0], args[1], args[2], args[3], args[4], args[5]); + ICHECK_EQ(args.num_args, 7); + this->Build(args[0], args[1], args[2], args[3], args[4], args[5], args[6]); }); } else if (name == "list_params") { return PackedFunc( @@ -220,6 +221,10 @@ class RelayBuildModule : public runtime::ModuleNode { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->executor_codegen_->GetFunctionMetadata(); }); + } else if (name == "get_executor_codegen_metadata") { + return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { + *rv = this->executor_codegen_->GetExecutorCodegenMetadata(); + }); } else if (name == "optimize") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { ICHECK_EQ(args.num_args, 2); @@ -297,10 +302,12 @@ class RelayBuildModule : public runtime::ModuleNode { * \param mod_name Name of the module */ void Build(IRModule mod, const TargetMap& targets, const tvm::Target& target_host, - const Executor& executor, const Runtime& runtime, const String mod_name) { + const Executor& executor, const Runtime& runtime, + const WorkspaceMemoryPools& workspace_memory_pools, const String mod_name) { VLOG_CONTEXT << "Build"; executor_ = executor; runtime_ = runtime; + workspace_memory_pools_ = workspace_memory_pools; config_ = 
CompilationConfig(PassContext::Current(), targets, target_host); BuildRelay(std::move(mod), mod_name); } @@ -408,8 +415,10 @@ class RelayBuildModule : public runtime::ModuleNode { // Instead of recreating the IRModule, we should look at the differences between this and the // incoming IRModule to see if we can just pass (IRModule, Function) to the code generator. Function func = Downcast(relay_module->Lookup("main")); - IRModule func_module = WithAttrs(IRModule::FromExpr(func), {{tvm::attr::kExecutor, executor_}, - {tvm::attr::kRuntime, runtime_}}); + IRModule func_module = WithAttrs(IRModule::FromExpr(func), + {{tvm::attr::kExecutor, executor_}, + {tvm::attr::kRuntime, runtime_}, + {tvm::attr::kWorkspaceMemoryPools, workspace_memory_pools_}}); // Generate code for the updated function. executor_codegen_ = MakeExecutorCodegen(executor_->name); @@ -470,8 +479,9 @@ class RelayBuildModule : public runtime::ModuleNode { } auto ext_mods = executor_codegen_->GetExternalModules(); - ret_.mod = tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods, host_target, - runtime_, executor_codegen_->GetMetadata()); + ret_.mod = + tvm::codegen::CreateMetadataModule(ret_.params, ret_.mod, ext_mods, host_target, runtime_, + executor_codegen_->GetExecutorCodegenMetadata()); // Remove external params which were stored in metadata module. for (tvm::runtime::Module mod : ext_mods) { auto pf_var = mod.GetFunction("get_const_vars"); @@ -493,6 +503,8 @@ class RelayBuildModule : public runtime::ModuleNode { Executor executor_; /*! \brief Runtime to codegen for */ Runtime runtime_; + /*! \brief Workspace memory pools to codegen for */ + WorkspaceMemoryPools workspace_memory_pools_; /*! \brief parameters */ std::unordered_map params_; /*! 
\brief building output */ diff --git a/src/relay/backend/graph_executor_codegen.cc b/src/relay/backend/graph_executor_codegen.cc index f61fe9b402b3..227f7bbfdf31 100644 --- a/src/relay/backend/graph_executor_codegen.cc +++ b/src/relay/backend/graph_executor_codegen.cc @@ -695,7 +695,7 @@ class GraphExecutorCodegenModule : public runtime::ModuleNode { }); } else if (name == "get_devices") { return PackedFunc([sptr_to_self](TVMArgs args, TVMRetValue* rv) { *rv = Array(); }); - } else if (name == "get_metadata") { + } else if (name == "get_executor_codegen_metadata") { return PackedFunc( [sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { *rv = this->output_.metadata; }); } else if (name == "get_function_metadata") { diff --git a/src/relay/backend/utils.cc b/src/relay/backend/utils.cc index 608d4cdb9f85..3aa65259d57a 100644 --- a/src/relay/backend/utils.cc +++ b/src/relay/backend/utils.cc @@ -178,6 +178,25 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ",\n relay_primfuncs=" << node->relay_primfuncs << ")"; }); +ExecutorCodegenMetadata::ExecutorCodegenMetadata( + Array inputs, Array pools, Array devices, Integer num_outputs, + String executor, String mod_name, String interface_api, bool unpacked_api, + Map pool_inputs) { + auto n = make_object(); + n->inputs = inputs; + n->pools = pools; + n->devices = devices; + n->num_outputs = num_outputs; + n->executor = executor; + n->interface_api = interface_api; + n->unpacked_api = unpacked_api; + n->mod_name = mod_name; + n->pool_inputs = pool_inputs; + data_ = std::move(n); +} + +TVM_REGISTER_NODE_TYPE(ExecutorCodegenMetadataNode); + Array GetPassPrefix(bool is_homegeneous, bool is_vm) { Array pass_seqs; // TODO(mbs): Would be nice to get spans on all diagnostics, but since they arg forgotton diff --git a/src/relay/backend/utils.h b/src/relay/backend/utils.h index cb019083a9d5..595527b2aabf 100644 --- a/src/relay/backend/utils.h +++ b/src/relay/backend/utils.h @@ -64,7 +64,7 @@ class ExecutorCodegenMetadataNode : public 
Object { /*! \brief pool information for the main function */ Array pools; /*! \brief number of outputs of the main function */ - unsigned int num_outputs = 1; + Integer num_outputs = 1; /*! \brief device contexts information for the main function */ Array devices; /*! \brief the executor to be used to run the model */ @@ -78,7 +78,16 @@ class ExecutorCodegenMetadataNode : public Object { String mod_name = ""; - static constexpr const uint32_t _type_index = TypeIndex::kDynamic; + void VisitAttrs(tvm::AttrVisitor* v) { + v->Visit("inputs", &inputs); + v->Visit("pools", &pools); + v->Visit("num_outputs", &num_outputs); + v->Visit("devices", &devices); + v->Visit("executor", &executor); + v->Visit("unpacked_api", &unpacked_api); + v->Visit("pool_inputs", &pool_inputs); + } + static constexpr const char* _type_key = "MetadataObj"; TVM_DECLARE_FINAL_OBJECT_INFO(ExecutorCodegenMetadataNode, Object); }; @@ -89,26 +98,14 @@ class ExecutorCodegenMetadataNode : public Object { class ExecutorCodegenMetadata : public ObjectRef { public: TVM_DLL ExecutorCodegenMetadata(Array inputs, Array pools, - Array devices, int num_outputs, String executor, + Array devices, Integer num_outputs, String executor, String mod_name, String interface_api = "packed", bool unpacked_api = false, Map pool_inputs = - Map()) { - auto n = make_object(); - n->inputs = inputs; - n->pools = pools; - n->devices = devices; - n->num_outputs = num_outputs; - n->executor = executor; - n->interface_api = interface_api; - n->unpacked_api = unpacked_api; - n->mod_name = mod_name; - n->pool_inputs = pool_inputs; - data_ = std::move(n); - } + Map()); - TVM_DEFINE_OBJECT_REF_METHODS(ExecutorCodegenMetadata, ObjectRef, ExecutorCodegenMetadataNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(ExecutorCodegenMetadataNode); + TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(ExecutorCodegenMetadata, ObjectRef, + ExecutorCodegenMetadataNode); }; /*! 
diff --git a/src/target/source/interface_c.cc b/src/target/source/interface_c.cc index ec1488488e59..f4cef74e8af9 100644 --- a/src/target/source/interface_c.cc +++ b/src/target/source/interface_c.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include @@ -41,11 +42,13 @@ using namespace tvm::relay::backend; class InterfaceCNode : public runtime::ModuleNode { public: InterfaceCNode(std::string module_name, Array inputs, Array outputs, - Array devices, int workspace_size) + Array pools, Array devices, + int workspace_size) : module_name_(module_name), inputs_(inputs), outputs_(outputs), devices_(devices), + pools_(FilterExternalPools(pools)), workspace_size_(workspace_size) {} const char* type_key() const { return "h"; } @@ -62,9 +65,25 @@ class InterfaceCNode : public runtime::ModuleNode { EmitBrief(code, "Device context pointers"); EmitStruct(code, "devices", devices_); } + if (!pools_.empty()) { + EmitBrief(code, "Workspace pool pointers"); + Array pool_names; + for (const tir::usmp::AllocatedPoolInfo pool : pools_) { + pool_names.push_back(pool->pool_info->pool_name); + } + EmitStruct(code, "workspace_pools", pool_names); + } EmitRunFunction(code); - EmitWorkspaceSize(code); + // Emit workspace + EmitIntegerValueMacro(code, "Workspace size", "WORKSPACE_SIZE", workspace_size_); + // Emit memory pool sizes + for (const tir::usmp::AllocatedPoolInfo pool : pools_) { + String pool_name = pool->pool_info->pool_name; + Integer pool_size = pool->allocated_size; + EmitIntegerValueMacro(code, SanitizeName(pool_name) + " size", + SanitizeName(pool_name) + "_WORKSPACE_POOL_SIZE", pool_size->value); + } EmitLowerHeaderGuard(code); return code.str(); @@ -116,11 +135,21 @@ class InterfaceCNode : public runtime::ModuleNode { code_stream << "};\n\n"; } + void EmitIntegerValueMacro(std::stringstream& code_stream, const std::string& brief_description, + const std::string& macro_name, int macro_value) { + EmitBrief(code_stream, brief_description); + std::string 
macro_name_prefixed = + ToCConstantStyle(PrefixGeneratedName({module_name_, macro_name})); + code_stream << "#define " << macro_name_prefixed << " " << macro_value << "\n"; + } + void EmitRunFunction(std::stringstream& code_stream) { std::string run_function = ToCVariableStyle(PrefixGeneratedName({module_name_, "run"})); std::string inputs_struct = ToCVariableStyle(PrefixGeneratedName({module_name_, "inputs"})); std::string outputs_struct = ToCVariableStyle(PrefixGeneratedName({module_name_, "outputs"})); std::string devices_struct = ToCVariableStyle(PrefixGeneratedName({module_name_, "devices"})); + std::string pools_struct = + ToCVariableStyle(PrefixGeneratedName({module_name_, "workspace_pools"})); code_stream << "/*!\n" << " * \\brief entrypoint function for TVM module \"" << module_name_ << "\"\n" @@ -130,40 +159,52 @@ class InterfaceCNode : public runtime::ModuleNode { if (!devices_.empty()) { code_stream << " * \\param devices Device context pointers for the module \n"; } + if (!pools_.empty()) { + code_stream << " * \\param workspace_pools Workspace memory pool pointers for the module \n"; + } code_stream << " */\n" - << "int32_t " << run_function << "(\n" - << " struct " << inputs_struct << "* inputs,\n"; + << "int32_t " << run_function << "(\n"; + std::stringstream call_args_ss; + call_args_ss << " struct " << inputs_struct << "* inputs,\n"; + call_args_ss << " struct " << outputs_struct << "* outputs,\n"; if (!devices_.empty()) { - code_stream << " struct " << outputs_struct << "* outputs,\n"; - code_stream << " struct " << devices_struct << "* devices\n"; - } else { - code_stream << " struct " << outputs_struct << "* outputs\n"; + call_args_ss << " struct " << devices_struct << "* devices,\n"; } - - code_stream << ");\n"; + if (!pools_.empty()) { + call_args_ss << " struct " << pools_struct << "* workspace_pools,\n"; + } + std::string call_args_str = call_args_ss.str(); + call_args_str.pop_back(); + call_args_str.pop_back(); + code_stream << 
call_args_str << "\n);\n"; } - void EmitWorkspaceSize(std::stringstream& code_stream) { - std::string workspace_size_name = - ToCConstantStyle(PrefixGeneratedName({module_name_, "WORKSPACE_SIZE"})); - code_stream << "/*!\n" - << " * \\brief Workspace size for TVM module \"" << module_name_ << "\"\n" - << " */\n" - << "#define " << workspace_size_name << " " << workspace_size_ << "\n"; + Array FilterExternalPools( + const Array& pools) { + Array external_pools; + for (tir::usmp::AllocatedPoolInfo pool : pools) { + if (!pool->pool_info->is_internal) { + external_pools.push_back(pool); + } + } + return external_pools; } std::string module_name_; Array inputs_; Array outputs_; Array devices_; + Array pools_; int workspace_size_; }; runtime::Module InterfaceCCreate(std::string module_name, Array inputs, - Array outputs, Array devices, int workspace_size) { - auto n = make_object(module_name, inputs, outputs, devices, workspace_size); + Array outputs, Array pools, + Array devices, int workspace_size) { + auto n = + make_object(module_name, inputs, outputs, pools, devices, workspace_size); return runtime::Module(n); } diff --git a/src/target/source/source_module.cc b/src/target/source/source_module.cc index 8faac3f1d966..29d867ad719d 100644 --- a/src/target/source/source_module.cc +++ b/src/target/source/source_module.cc @@ -273,7 +273,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { } call_args_ss << " " << input_var->name_hint << ","; } - for (unsigned int i = 0; i < metadata_->num_outputs; ++i) { + for (int i = 0; i < metadata_->num_outputs->value; ++i) { call_args_ss << "void* output" << i << ","; } for (const tir::Var& pool_var : metadata_->pools) { @@ -300,7 +300,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { for (unsigned int i = 0; i < metadata_->inputs.size(); ++i) { call_args_ss << "((DLTensor*)(((TVMValue*)args)[" << i << "].v_handle))[0].data,"; } - for (unsigned int i = 0; i < metadata_->num_outputs; ++i) { + for 
(int i = 0; i < metadata_->num_outputs->value; ++i) { int j = metadata_->inputs.size() + i; call_args_ss << "((DLTensor*)(((TVMValue*)args)[" << j << "].v_handle))[0].data,"; } @@ -328,7 +328,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { entrypoint_arg_count++; run_func_arg_count++; } - for (unsigned int i = 0; i < metadata_->num_outputs; i++) { + for (int i = 0; i < metadata_->num_outputs->value; i++) { run_func_to_entry_point_args[run_func_arg_count] = Integer(entrypoint_arg_count); entrypoint_arg_count++; run_func_arg_count++; @@ -356,7 +356,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { // We are creating a copy of the set of pointers size_t number_of_io_tensors = - metadata_->inputs.size() + metadata_->num_outputs + metadata_->pools.size(); + metadata_->inputs.size() + metadata_->num_outputs->value + metadata_->pools.size(); code_ << "TVMValue tensors[" << number_of_io_tensors << "];\n"; std::unordered_map run_func_to_entry_point_args = @@ -395,7 +395,7 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { } call_args_ss << " " << relay::backend::SanitizeName(input_var->name_hint) << ","; } - for (unsigned int i = 0; i < metadata_->num_outputs; ++i) { + for (int i = 0; i < metadata_->num_outputs->value; ++i) { call_args_ss << "void* output" << i << ","; } for (const tir::Var& pool_var : metadata_->pools) { @@ -416,12 +416,29 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { code_ << ");\n"; code_ << "int32_t " << entrypoint_name << "("; - code_ << "struct " << runtime::get_name_mangled(mod_name, "inputs") << "* inputs,"; - if (!metadata_->devices.empty()) { - code_ << "struct " << runtime::get_name_mangled(mod_name, "outputs") << "* outputs,"; - code_ << "struct " << runtime::get_name_mangled(mod_name, "devices") << "* devices"; - } else { - code_ << "struct " << runtime::get_name_mangled(mod_name, "outputs") << "* outputs"; + { + std::stringstream call_args_ss; + call_args_ss 
<< "struct " << runtime::get_name_mangled(mod_name, "inputs") << "* inputs,"; + call_args_ss << "struct " << runtime::get_name_mangled(mod_name, "outputs") << "* outputs,"; + if (!metadata_->pools.empty()) { + bool is_external_pools_present = false; + for (tir::Var pool_var : metadata_->pools) { + if (!IsInternalWorkspaceBuffer(pool_var)) { + is_external_pools_present = true; + break; + } + } + if (is_external_pools_present) { + call_args_ss << "struct " << runtime::get_name_mangled(mod_name, "workspace_pools") + << "* workspace_pools,"; + } + } + if (!metadata_->devices.empty()) { + call_args_ss << "struct " << runtime::get_name_mangled(mod_name, "devices") << "* devices,"; + } + std::string call_args_str = call_args_ss.str(); + call_args_str.pop_back(); + code_ << call_args_str; } code_ << ") {" @@ -432,17 +449,19 @@ class CSourceCrtMetadataModuleNode : public runtime::ModuleNode { for (const auto& input : metadata_->inputs) { call_args_ss << "inputs->" << relay::backend::SanitizeName(input->name_hint) << ","; } - if (metadata_->num_outputs == 1) { + if (metadata_->num_outputs->value == 1) { call_args_ss << "outputs->output,"; } else { - for (unsigned int i = 0; i < metadata_->num_outputs; ++i) { + for (int i = 0; i < metadata_->num_outputs->value; ++i) { call_args_ss << "outputs->output" << i << ","; } } for (const tir::Var& pool_var : metadata_->pools) { + String pool_name = metadata_->pool_inputs.value()[pool_var]->pool_info->pool_name; if (IsInternalWorkspaceBuffer(pool_var)) { - call_args_ss << "&" << metadata_->pool_inputs.value()[pool_var]->pool_info->pool_name - << ","; + call_args_ss << "&" << pool_name << ","; + } else { + call_args_ss << "workspace_pools->" << relay::backend::SanitizeName(pool_name) << ","; } } for (const String& device : metadata_->devices) { diff --git a/src/tir/usmp/transform/assign_pool_info.cc b/src/tir/usmp/transform/assign_pool_info.cc index e75610ea0551..2386b5fef77d 100644 --- a/src/tir/usmp/transform/assign_pool_info.cc +++ 
b/src/tir/usmp/transform/assign_pool_info.cc @@ -46,23 +46,24 @@ class PoolInfoAssigner : public StmtExprMutator { ICHECK(main_func.defined()) << "main function is not in the module"; Optional target_host = main_func->GetAttr(tvm::attr::kTarget); ICHECK(target_host) << "main function does not have a target attr"; - Array pool_infos = - module->GetAttr>(tvm::attr::kPoolInfoIRModuleAttr) - .value_or({usmp::PoolInfo( - "global_workspace", - {{target_host.value(), String(PoolInfo::kTargetPoolReadWriteAccess)}}, + WorkspaceMemoryPools workspace_pools = + module->GetAttr(tvm::attr::kWorkspaceMemoryPools) + .value_or(WorkspaceMemoryPools({PoolInfo( + "global_workspace", {{target_host.value(), PoolInfo::kTargetPoolReadWriteAccess}}, PoolInfo::kUnrestrictedPoolSizeHint, PoolInfo::kUnknownClockFrequency, PoolInfo::kUnknownReadBandwidth, PoolInfo::kUnknownWriteBandwidth, 0, 0, - {{target_host.value(), 1}}, Bool(true))}); - for (const usmp::PoolInfo& pool_info : pool_infos) { + {{target_host.value(), 1}}, Bool(true))})); + Array pool_infos = workspace_pools->pools; + for (const PoolInfo& pool_info : pool_infos) { for (const auto& kv : pool_info->target_access) { - Target tgt = kv.first; - if (target_pool_infos_.find(tgt) == target_pool_infos_.end()) { - target_pool_infos_.Set(tgt, Array()); + Target target = kv.first; + String target_str = target->str(); + if (target_pool_infos_.find(target_str) == target_pool_infos_.end()) { + target_pool_infos_.Set(target_str, Array()); } - Array pool_info_arr = target_pool_infos_[tgt]; + Array pool_info_arr = target_pool_infos_[target_str]; pool_info_arr.push_back(pool_info); - target_pool_infos_.Set(tgt, pool_info_arr); + target_pool_infos_.Set(target_str, pool_info_arr); } } mod_ = module->ShallowCopy(); @@ -74,7 +75,7 @@ class PoolInfoAssigner : public StmtExprMutator { Stmt VisitStmt_(const AllocateNode* op) override; IRModule mod_; - Map> target_pool_infos_; + Map> target_pool_infos_; PrimFunc func_; }; @@ -83,7 +84,7 @@ Stmt 
PoolInfoAssigner::VisitStmt_(const AllocateNode* op) { ICHECK(tgt) << "The following PrimFunc does not have a target attr: \n" << func_; Map annotations = Map(op->annotations); if (op->annotations.find(kPoolCandidatesAllocateAttr) == op->annotations.end()) { - annotations.Set(kPoolCandidatesAllocateAttr, target_pool_infos_[tgt.value()]); + annotations.Set(kPoolCandidatesAllocateAttr, target_pool_infos_[tgt.value()->str()]); } Stmt body = VisitStmt(op->body); auto allocate = diff --git a/src/tir/usmp/utils.cc b/src/tir/usmp/utils.cc index b7c1a5f59f24..64428fd9c49e 100644 --- a/src/tir/usmp/utils.cc +++ b/src/tir/usmp/utils.cc @@ -22,6 +22,7 @@ * \brief Utilities for Unified Static Memory Planner */ +#include #include #include #include @@ -92,49 +93,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) << ",\n memory_pressure=" << node->memory_pressure << ")"; }); -PoolInfo::PoolInfo(String pool_name, Map target_access, Integer size_hint_bytes, - Integer clock_frequency_hz, Integer read_bandwidth_bytes_per_cycle, - Integer write_bandwidth_bytes_per_cycle, Integer read_latency_cycles, - Integer write_latency_cycles, Map target_burst_bytes, - Bool is_internal) { - auto poolinfo_node = make_object(); - poolinfo_node->pool_name = pool_name; - poolinfo_node->size_hint_bytes = size_hint_bytes; - poolinfo_node->target_access = target_access; - poolinfo_node->clock_frequency_hz = clock_frequency_hz; - poolinfo_node->read_bandwidth_bytes_per_cycle = read_bandwidth_bytes_per_cycle; - poolinfo_node->write_bandwidth_bytes_per_cycle = write_bandwidth_bytes_per_cycle; - poolinfo_node->read_latency_cycles = read_latency_cycles; - poolinfo_node->write_latency_cycles = write_latency_cycles; - poolinfo_node->target_burst_bytes = target_burst_bytes; - poolinfo_node->is_internal = is_internal; - data_ = std::move(poolinfo_node); -} - -TVM_REGISTER_NODE_TYPE(PoolInfoNode); -TVM_REGISTER_GLOBAL("tir.usmp.PoolInfo") - .set_body_typed([](String pool_name, Map target_access, Integer 
size_hint_bytes, - Integer clock_frequency_hz, Integer read_bandwidth_bytes_per_cycle, - Integer write_bandwidth_bytes_per_cycle, Integer read_latency_cycles, - Integer write_latency_cycles, Map target_burst_bytes) { - return PoolInfo(pool_name, target_access, size_hint_bytes, clock_frequency_hz, - read_bandwidth_bytes_per_cycle, write_bandwidth_bytes_per_cycle, - read_latency_cycles, write_latency_cycles, target_burst_bytes, Bool(false)); - }); - -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) - .set_dispatch([](const ObjectRef& ref, ReprPrinter* p) { - auto* node = static_cast(ref.get()); - p->stream << "PoolInfoNode(\n" - << " pool_name=" << node->pool_name << ",\n target_access=" << node->target_access - << ",\n size_hint_bytes=" << node->size_hint_bytes - << ",\n clock_frequency_hz=" << node->clock_frequency_hz - << ",\n read_bandwidth_bytes_per_cycle=" << node->read_bandwidth_bytes_per_cycle - << ",\n write_bandwidth_bytes_per_cycle=" << node->write_bandwidth_bytes_per_cycle - << ",\n read_latency_cycles=" << node->read_latency_cycles - << ",\n write_latency_cycles=" << node->write_latency_cycles - << ",\n target_burst_bytes=" << node->target_burst_bytes << ")"; - }); PoolAllocation::PoolAllocation(PoolInfo pool_info, Integer byte_offset) { auto pool_allocation_node = make_object(); diff --git a/tests/cpp/relay_build_module_test.cc b/tests/cpp/relay_build_module_test.cc index 568338bcb868..859f587f5a11 100644 --- a/tests/cpp/relay_build_module_test.cc +++ b/tests/cpp/relay_build_module_test.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -128,7 +129,8 @@ TEST(Relay, BuildModule) { targets.Set(0, llvm_tgt); auto relay_mod = tvm::IRModule::FromExpr(func); ICHECK(relay_mod.defined()) << "Module must be defined"; - build_f(relay_mod, targets, llvm_tgt, Executor::Create("graph"), Runtime::Create("cpp"), ""); + build_f(relay_mod, targets, llvm_tgt, Executor::Create("graph"), Runtime::Create("cpp"), + WorkspaceMemoryPools(), ""); 
std::string json = json_f(); tvm::runtime::Module mod = mod_f(); // run diff --git a/tests/cpp/runtime_test.cc b/tests/cpp/runtime_test.cc index c87639fffd2c..57686baf7b46 100644 --- a/tests/cpp/runtime_test.cc +++ b/tests/cpp/runtime_test.cc @@ -19,6 +19,7 @@ #include #include +#include #include #include #include @@ -114,7 +115,8 @@ TEST(Runtime, ZeroCopy) { targets.Set(0, llvm_tgt); auto relay_mod = tvm::IRModule::FromExpr(func); ICHECK(relay_mod.defined()) << "Module must be defined"; - build_f(relay_mod, targets, llvm_tgt, Executor::Create("graph"), Runtime::Create("cpp"), ""); + build_f(relay_mod, targets, llvm_tgt, Executor::Create("graph"), Runtime::Create("cpp"), + WorkspaceMemoryPools(), ""); // create graph executor std::string json = json_f(); tvm::runtime::Module mod = mod_f(); diff --git a/tests/cpp/target/source/interface_c_test.cc b/tests/cpp/target/source/interface_c_test.cc index 76496c20f4c9..71657a89e47f 100644 --- a/tests/cpp/target/source/interface_c_test.cc +++ b/tests/cpp/target/source/interface_c_test.cc @@ -22,6 +22,7 @@ #include #include #include +#include using ::testing::HasSubstr; @@ -29,7 +30,8 @@ namespace tvm { namespace codegen { runtime::Module InterfaceCCreate(std::string module_name, Array inputs, - Array outputs, Array devices, int workspace_size); + Array outputs, Array pools, + Array devices, int workspace_size); namespace { @@ -50,7 +52,7 @@ TEST(InterfaceAPI, ContainsHeaderGuards) { << "#endif // TVMGEN_ULTIMATE_CAT_SPOTTER_H_\n"; runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, 0); + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, 0); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, HasSubstr(upper_header_guard.str())); @@ -71,9 +73,8 @@ TEST(InterfaceAPI, ContainsRunFunction) { << ");\n"; runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, 0); + 
InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, 0); std::string header_source = test_module->GetSource(); - ASSERT_THAT(header_source, HasSubstr(run_function.str())); } @@ -93,7 +94,32 @@ TEST(InterfaceAPI, ContainsRunFunctionWithDevices) { << ");\n"; runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {"device"}, 0); + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {"device"}, 0); + std::string header_source = test_module->GetSource(); + + ASSERT_THAT(header_source, HasSubstr(run_function.str())); +} + +TEST(InterfaceAPI, ContainsRunFunctionWithWorkspacePools) { + std::stringstream run_function; + + run_function << "/*!\n" + << " * \\brief entrypoint function for TVM module \"ultimate_cat_spotter\"\n" + << " * \\param inputs Input tensors for the module \n" + << " * \\param outputs Output tensors for the module \n" + << " * \\param workspace_pools Workspace memory pool pointers for the module \n" + << " */\n" + << "int32_t tvmgen_ultimate_cat_spotter_run(\n" + << " struct tvmgen_ultimate_cat_spotter_inputs* inputs,\n" + << " struct tvmgen_ultimate_cat_spotter_outputs* outputs,\n" + << " struct tvmgen_ultimate_cat_spotter_workspace_pools* workspace_pools\n" + << ");\n"; + + PoolInfo pool_info = PoolInfo("my_memory_pool", {}); + tir::usmp::AllocatedPoolInfo allocated_pool_info = + tir::usmp::AllocatedPoolInfo(pool_info, 100000); + runtime::Module test_module = + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {allocated_pool_info}, {}, 0); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, HasSubstr(run_function.str())); @@ -110,7 +136,7 @@ TEST(InterfaceAPI, ContainsInputStructSingle) { << "};\n\n"; runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, 0); + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, 0); std::string header_source = test_module->GetSource(); 
ASSERT_THAT(header_source, HasSubstr(input_struct.str())); @@ -125,7 +151,7 @@ TEST(InterfaceAPI, ContainsInputStructMany) { << "};\n\n"; runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input1", "input2"}, {"output"}, {}, 0); + InterfaceCCreate("ultimate_cat_spotter", {"input1", "input2"}, {"output"}, {}, {}, 0); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, HasSubstr(input_struct.str())); @@ -140,7 +166,7 @@ TEST(InterfaceAPI, ContainsInputStructSanitised) { << "};\n\n"; runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input+1", "input+2"}, {"output"}, {}, 0); + InterfaceCCreate("ultimate_cat_spotter", {"input+1", "input+2"}, {"output"}, {}, {}, 0); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, HasSubstr(input_struct.str())); @@ -148,7 +174,7 @@ TEST(InterfaceAPI, ContainsInputStructSanitised) { TEST(InterfaceAPI, ContainsInputStructClash) { runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input+", "input-"}, {"output"}, {}, 0); + InterfaceCCreate("ultimate_cat_spotter", {"input+", "input-"}, {"output"}, {}, {}, 0); ASSERT_THROW(test_module->GetSource(), InternalError); } @@ -163,7 +189,7 @@ TEST(InterfaceAPI, ContainsOutputStructSingle) { << "};\n\n"; runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, 0); + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, 0); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, HasSubstr(output_struct.str())); @@ -178,7 +204,7 @@ TEST(InterfaceAPI, ContainsOutputStructMany) { << "};\n\n"; runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output1", "output2"}, {}, 0); + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output1", "output2"}, {}, {}, 0); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, 
HasSubstr(output_struct.str())); @@ -193,7 +219,7 @@ TEST(InterfaceAPI, ContainsOutputStructSanitised) { << "};\n\n"; runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output+1", "output-2"}, {}, 0); + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output+1", "output-2"}, {}, {}, 0); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, HasSubstr(output_struct.str())); @@ -201,7 +227,7 @@ TEST(InterfaceAPI, ContainsOutputStructSanitised) { TEST(InterfaceAPI, ContainsOutputStructClash) { runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output+", "output-"}, {}, 0); + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output+", "output-"}, {}, {}, 0); ASSERT_THROW(test_module->GetSource(), InternalError); } @@ -215,7 +241,7 @@ TEST(InterfaceAPI, NoDeviceAPIStructIfNoDevices) { << "};\n\n"; runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, 0); + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, 0); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, Not(HasSubstr(device_struct.str()))); @@ -232,7 +258,7 @@ TEST(InterfaceAPI, ContainsDeviceStructSingle) { << "};\n\n"; runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {"device"}, 0); + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {"device"}, 0); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, HasSubstr(device_struct.str())); @@ -246,8 +272,8 @@ TEST(InterfaceAPI, ContainsDeviceStructMany) { << " void* device2;\n" << "};\n\n"; - runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {"device1", "device2"}, 0); + runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, + {"device1", "device2"}, 0); std::string header_source = 
test_module->GetSource(); ASSERT_THAT(header_source, HasSubstr(device_struct.str())); @@ -261,22 +287,22 @@ TEST(InterfaceAPI, ContainsDeviceStructSanitised) { << " void* device_2;\n" << "};\n\n"; - runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {"device+1", "device+2"}, 0); + runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, + {"device+1", "device+2"}, 0); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, HasSubstr(device_struct.str())); } TEST(InterfaceAPI, ContainsDeviceStructClash) { - runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {"device+", "device-"}, 0); + runtime::Module test_module = InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, + {"device+", "device-"}, 0); ASSERT_THROW(test_module->GetSource(), InternalError); } TEST(InterfaceAPI, ContainsWorkspaceSize) { runtime::Module test_module = - InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, 765432); + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {}, {}, 765432); std::string header_source = test_module->GetSource(); ASSERT_THAT(header_source, @@ -286,6 +312,119 @@ TEST(InterfaceAPI, ContainsWorkspaceSize) { HasSubstr("#define TVMGEN_ULTIMATE_CAT_SPOTTER_WORKSPACE_SIZE 765432")); } +TEST(InterfaceAPI, ContainsWorkspacePoolStructSingle) { + PoolInfo pool_info = PoolInfo("my_memory_pool", {}); + tir::usmp::AllocatedPoolInfo allocated_pool_info = + tir::usmp::AllocatedPoolInfo(pool_info, 100000); + + std::stringstream workspace_struct; + + workspace_struct + << "/*!\n" + << " * \\brief Workspace pool pointers for TVM module \"ultimate_cat_spotter\" \n" + << " */\n" + << "struct tvmgen_ultimate_cat_spotter_workspace_pools {\n" + << " void* my_memory_pool;\n" + << "};\n\n"; + + runtime::Module test_module = + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, 
{allocated_pool_info}, {}, 0); + std::string header_source = test_module->GetSource(); + + ASSERT_THAT(header_source, HasSubstr(workspace_struct.str())); + + ASSERT_THAT(header_source, + HasSubstr("* \\brief my_memory_pool size for TVM module \"ultimate_cat_spotter\"")); + + ASSERT_THAT( + header_source, + HasSubstr("#define TVMGEN_ULTIMATE_CAT_SPOTTER_MY_MEMORY_POOL_WORKSPACE_POOL_SIZE 100000")); +} + +TEST(InterfaceAPI, ContainsWorkspacePoolStructMany) { + PoolInfo pool_info1 = PoolInfo("my_memory_pool_1", {}); + tir::usmp::AllocatedPoolInfo allocated_pool_info1 = + tir::usmp::AllocatedPoolInfo(pool_info1, 100000); + PoolInfo pool_info2 = PoolInfo("my_memory_pool_2", {}); + tir::usmp::AllocatedPoolInfo allocated_pool_info2 = + tir::usmp::AllocatedPoolInfo(pool_info2, 200000); + + std::stringstream workspace_struct; + + workspace_struct + << "/*!\n" + << " * \\brief Workspace pool pointers for TVM module \"ultimate_cat_spotter\" \n" + << " */\n" + << "struct tvmgen_ultimate_cat_spotter_workspace_pools {\n" + << " void* my_memory_pool_1;\n" + << " void* my_memory_pool_2;\n" + << "};\n\n"; + + runtime::Module test_module = + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, + {allocated_pool_info1, allocated_pool_info2}, {}, 0); + std::string header_source = test_module->GetSource(); + + ASSERT_THAT(header_source, HasSubstr(workspace_struct.str())); + + ASSERT_THAT(header_source, + HasSubstr("* \\brief my_memory_pool_1 size for TVM module \"ultimate_cat_spotter\"")); + + ASSERT_THAT( + header_source, + HasSubstr("#define TVMGEN_ULTIMATE_CAT_SPOTTER_MY_MEMORY_POOL_1_WORKSPACE_POOL_SIZE 100000")); + + ASSERT_THAT(header_source, + HasSubstr("* \\brief my_memory_pool_2 size for TVM module \"ultimate_cat_spotter\"")); + + ASSERT_THAT( + header_source, + HasSubstr("#define TVMGEN_ULTIMATE_CAT_SPOTTER_MY_MEMORY_POOL_2_WORKSPACE_POOL_SIZE 200000")); +} + +TEST(InterfaceAPI, ContainsWorkspacePoolStructSanitized) { + PoolInfo pool_info = 
PoolInfo("my_memory_pool+1", {}); + tir::usmp::AllocatedPoolInfo allocated_pool_info = + tir::usmp::AllocatedPoolInfo(pool_info, 100000); + + std::stringstream workspace_struct; + + workspace_struct + << "/*!\n" + << " * \\brief Workspace pool pointers for TVM module \"ultimate_cat_spotter\" \n" + << " */\n" + << "struct tvmgen_ultimate_cat_spotter_workspace_pools {\n" + << " void* my_memory_pool_1;\n" + << "};\n\n"; + + runtime::Module test_module = + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, {allocated_pool_info}, {}, 0); + std::string header_source = test_module->GetSource(); + + ASSERT_THAT(header_source, HasSubstr(workspace_struct.str())); + + ASSERT_THAT(header_source, + HasSubstr("* \\brief my_memory_pool_1 size for TVM module \"ultimate_cat_spotter\"")); + + ASSERT_THAT( + header_source, + HasSubstr("#define TVMGEN_ULTIMATE_CAT_SPOTTER_MY_MEMORY_POOL_1_WORKSPACE_POOL_SIZE 100000")); +} + +TEST(InterfaceAPI, ContainsWorkspacePoolStructClash) { + PoolInfo pool_info1 = PoolInfo("my_memory_pool+", {}); + tir::usmp::AllocatedPoolInfo allocated_pool_info1 = + tir::usmp::AllocatedPoolInfo(pool_info1, 100000); + PoolInfo pool_info2 = PoolInfo("my_memory_pool-", {}); + tir::usmp::AllocatedPoolInfo allocated_pool_info2 = + tir::usmp::AllocatedPoolInfo(pool_info2, 200000); + + runtime::Module test_module = + InterfaceCCreate("ultimate_cat_spotter", {"input"}, {"output"}, + {allocated_pool_info1, allocated_pool_info2}, {}, 0); + ASSERT_THROW(test_module->GetSource(), InternalError); +} + } // namespace } // namespace codegen } // namespace tvm diff --git a/tests/micro/zephyr/test_utils.py b/tests/micro/zephyr/test_utils.py index e3a52dc79ab1..846dabee617c 100644 --- a/tests/micro/zephyr/test_utils.py +++ b/tests/micro/zephyr/test_utils.py @@ -210,7 +210,7 @@ def generate_project( model_files_path, arcname=os.path.relpath(model_files_path, tar_temp_dir) ) header_path = generate_c_interface_header( - lowered.libmod_name, ["input_1"], ["output"], [], 
0, model_files_path + lowered.libmod_name, ["input_1"], ["output"], [], [], 0, model_files_path ) tf.add(header_path, arcname=os.path.relpath(header_path, tar_temp_dir)) diff --git a/tests/python/contrib/test_ethosu/test_networks.py b/tests/python/contrib/test_ethosu/test_networks.py index 68df66625856..de263c18f368 100644 --- a/tests/python/contrib/test_ethosu/test_networks.py +++ b/tests/python/contrib/test_ethosu/test_networks.py @@ -54,7 +54,7 @@ def test_forward_mobilenet_v1(accel_type): in_min, in_max = util.get_range_for_dtype_str(input_dtype) input_data = np.random.randint(in_min, high=in_max, size=input_shape, dtype=input_dtype) - relay_mod, params = convert_to_relay(tflite_model_buf, input_data, "input") + relay_mod, params = convert_to_relay(tflite_model_buf) input_data = {input_tensor: input_data} output_data = generate_ref_data(relay_mod, input_data) diff --git a/tests/python/relay/aot/aot_test_utils.py b/tests/python/relay/aot/aot_test_utils.py index 6900bdc2e6e1..7cbf74e9a811 100644 --- a/tests/python/relay/aot/aot_test_utils.py +++ b/tests/python/relay/aot/aot_test_utils.py @@ -161,16 +161,8 @@ def mangle_name(mod_name, name): def convert_to_relay( tflite_model_buf, - input_data, - input_node, ): """Convert a tflite model buffer in a Relay module""" - - def convert_to_list(x): - if not isinstance(x, list): - x = [x] - return x - # TFLite.Model.Model has changed to TFLite.Model from 1.14 to 2.1 try: import tflite.Model @@ -183,18 +175,7 @@ def convert_to_list(x): except ImportError: raise ImportError("The tflite package must be installed") - input_data = convert_to_list(input_data) - input_node = convert_to_list(input_node) - - shape_dict = {} - dtype_dict = {} - for i, e in enumerate(input_node): - shape_dict[e] = input_data[i].shape - dtype_dict[e] = input_data[i].dtype.name - - mod, params = relay.frontend.from_tflite( - tflite_model, shape_dict=shape_dict, dtype_dict=dtype_dict - ) + mod, params = relay.frontend.from_tflite(tflite_model) 
mod["main"] = relay.build_module.bind_params_by_name(mod["main"], params) return mod, params @@ -345,6 +326,16 @@ def emit_main_device_structs(main_file, devices, mod_name): main_file.write("};\n") +def emit_main_workspace_pool_structs(main_file, workspace_pool_names, mod_name): + if workspace_pool_names and len(workspace_pool_names) > 0: + main_file.write( + f"struct {mangle_name(mod_name, 'workspace_pools')} {mangle_name(mod_name, 'workspace_pools')} = {{" + ) + for workspace_pool_name in workspace_pool_names: + main_file.write(f"\t.{workspace_pool_name} = {workspace_pool_name},\n") + main_file.write("};\n") + + def emit_main_data_structs(main_file, input_map, output_list, mod_name): main_file.write( f"struct {mangle_name(mod_name, 'inputs')} {mangle_name(mod_name, 'inputs')} = {{" @@ -384,20 +375,25 @@ def emit_main_data_setup(main_file, input_map, output_list, mod_name): main_file.write("};\n") -def emit_main_c_interface_call(main_file, devices, mod_name): +def emit_main_c_interface_call(main_file, devices, workspace_pool_names, mod_name): + sub_strings = list() + sub_strings.append(f'{mangle_name(mod_name,"run")}(') + sub_strings.append(f'&{mangle_name(mod_name,"inputs")}, ') + sub_strings.append(f'&{mangle_name(mod_name,"outputs")}, ') + if workspace_pool_names: + sub_strings.append(f'&{mangle_name(mod_name,"workspace_pools")}, ') if devices: - main_file.write( - f'{mangle_name(mod_name,"run")}(' - f'&{mangle_name(mod_name,"inputs")}, ' - f'&{mangle_name(mod_name,"outputs")}, ' - f'&{mangle_name(mod_name,"devices")});\n' - ) - else: - main_file.write( - f'{mangle_name(mod_name,"run")}(' - f'&{mangle_name(mod_name,"inputs")}, ' - f'&{mangle_name(mod_name,"outputs")});\n' - ) + sub_strings.append(f'&{mangle_name(mod_name,"devices")}, ') + # Removing the last two characters that is a comma and a space + sub_strings[-1] = sub_strings[-1][:-2] + # Adding brackets and newline instead + sub_strings[-1] = sub_strings[-1] + ");\n" + + main_file_string = "" + for 
sub_string in sub_strings: + main_file_string += sub_string + + main_file.write(main_file_string) def emit_main_fake_packed_values(main_file): @@ -541,10 +537,21 @@ def create_main( if interface_api == "c": for compiled_model in compiled_models: model = compiled_model.model + executor_codegen_metadata = ( + compiled_model.executor_factory.executor_codegen_metadata + ) devices = compiled_model.executor_factory.get_devices() + workspace_pool_names = None + if executor_codegen_metadata.pool_inputs: + workspace_pool_names = [ + allocated_pool.pool_info.pool_name + for allocated_pool in dict(executor_codegen_metadata.pool_inputs).values() + if not allocated_pool.pool_info.is_internal + ] emit_main_device_structs(main_file, devices, model.name) + emit_main_workspace_pool_structs(main_file, workspace_pool_names, model.name) emit_main_data_structs(main_file, model.inputs, model.outputs, model.name) - emit_main_c_interface_call(main_file, devices, model.name) + emit_main_c_interface_call(main_file, devices, workspace_pool_names, model.name) else: emit_main_fake_packed_values(main_file) for compiled_model in compiled_models: @@ -599,8 +606,8 @@ def compile_models( enable_op_fusion: bool = True, pass_config: Dict[str, Any] = None, use_runtime_executor: bool = True, - target: str = "c", - target_opts: Dict = None, + target: tvm.target.Target = tvm.target.Target("c"), + workspace_memory_pools=None, ) -> List[AOTCompiledTestModel]: """ This method generates runtime.Modules for the tests @@ -617,9 +624,6 @@ def compile_models( "unpacked-api": use_unpacked_api, }, ) - if target_opts: - for key, val in target_opts.items(): - target += f" {key}={val}" config = {"tir.disable_vectorize": True} if pass_config: @@ -634,9 +638,10 @@ def compile_models( if use_runtime_executor: executor_factory = tvm.relay.build( model.module, - tvm.target.Target(target, host=target), + target, executor=executor, runtime=runtime, + workspace_memory_pools=workspace_memory_pools, params=model.params, 
mod_name=model.name, ) @@ -776,7 +781,6 @@ def run_and_check( print("Run command:\n", run_command) ret = subprocess_log_output(run_command, build_path, run_log_path) assert ret == 0 - with open(run_log_path) as run_log: assert AOT_SUCCESS_TOKEN in run_log.read() @@ -805,6 +809,11 @@ def compile_and_run( verbose: bool Prints commands to build and run AOT test runner """ + + if target_opts: + for key, val in target_opts.items(): + target += f" {key}={val}" + compiled_test_mods = compile_models( models=models, interface_api=interface_api, @@ -813,8 +822,7 @@ def compile_and_run( enable_op_fusion=enable_op_fusion, pass_config=runner.pass_config, use_runtime_executor=use_runtime_executor, - target=target, - target_opts=target_opts, + target=tvm.target.Target(target), ) run_and_check( diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 566566da1dce..ab51311d680a 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -529,7 +529,7 @@ def test_quant_mobilenet_tfl(): data_shape = (1, 224, 224, 3) in_min, in_max = (0, 255) data = np.random.randint(in_min, high=in_max, size=data_shape, dtype="uint8") - mod, params = convert_to_relay(tflite_model_buf, data, "input") + mod, params = convert_to_relay(tflite_model_buf) inputs = {"input": data} output_list = generate_ref_data(mod, inputs, params) compile_and_run( @@ -709,12 +709,12 @@ def test_constants_alignment(constants_byte_alignment): data = np.random.uniform(size=data_shape).astype("float32") inputs = {"data": data} output_list = generate_ref_data(mod, inputs, params) - target_opts = {"-constants-byte-alignment": constants_byte_alignment} + target = f"c -constants-byte-alignment={constants_byte_alignment}" compiled_test_mods = compile_models( AOTTestModel(module=mod, inputs=inputs, outputs=output_list, params=params), interface_api, use_unpacked_api, - target_opts=target_opts, + target=tvm.target.Target(target, host=target), ) source 
= compiled_test_mods[0].executor_factory.lib.imported_modules[0].get_source() assert f'__attribute__((section(".rodata.tvm"), aligned({constants_byte_alignment})))' in source diff --git a/tests/python/relay/aot/test_crt_aot_usmp.py b/tests/python/relay/aot/test_crt_aot_usmp.py index 6a040d9a9e79..a27609cc07ad 100644 --- a/tests/python/relay/aot/test_crt_aot_usmp.py +++ b/tests/python/relay/aot/test_crt_aot_usmp.py @@ -29,6 +29,7 @@ from tvm.relay.testing import byoc from tvm.relay.op.annotation import compiler_begin, compiler_end from tvm.relay.backend import Executor, Runtime +from tvm import WorkspaceMemoryPools, PoolInfo from aot_test_utils import ( AOTTestModel, AOTTestRunner, @@ -201,10 +202,31 @@ def test_byoc_microtvm(merge_compiler_regions): ) +def _get_relay_module_and_inputs_from_tflite_file(tflite_model_file): + with open(tflite_model_file, "rb") as f: + tflite_model_buf = f.read() + mod, params = convert_to_relay(tflite_model_buf) + + inputs = dict() + for param in mod["main"].params: + name = str(param.name_hint) + data_shape = [int(i) for i in param.type_annotation.shape] + dtype = str(param.type_annotation.dtype) + in_min, in_max = (np.iinfo(dtype).min, np.iinfo(dtype).max) + data = np.random.randint(in_min, high=in_max, size=data_shape, dtype=dtype) + inputs[name] = data + + return mod, inputs, params + + MOBILENET_V1_URL = ( "https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_2018_08_02/mobilenet_v1_1.0_224_quant.tgz", "mobilenet_v1_1.0_224_quant.tflite", ) +MOBILENET_V2_URL = ( + "https://storage.googleapis.com/download.tensorflow.org/models/tflite_11_05_08/mobilenet_v2_1.0_224_quant.tgz", + "mobilenet_v2_1.0_224_quant.tflite", +) @pytest.mark.parametrize( @@ -215,7 +237,7 @@ def test_byoc_microtvm(merge_compiler_regions): (MOBILENET_V1_URL, "hill_climb", 3240064), ], ) -def test_tflite_model(model_url, usmp_algo, workspace_size): +def test_tflite_model_u1_usecase(model_url, usmp_algo, workspace_size): """This checks for ML 
models and the memory used by them when using USMP with different algorithms""" pytest.importorskip("tflite") @@ -231,13 +253,7 @@ def test_tflite_model(model_url, usmp_algo, workspace_size): model_url[0], model_url[1], ) - with open(tflite_model_file, "rb") as f: - tflite_model_buf = f.read() - data_shape = (1, 224, 224, 3) - in_min, in_max = (0, 255) - data = np.random.randint(in_min, high=in_max, size=data_shape, dtype="uint8") - mod, params = convert_to_relay(tflite_model_buf, data, "input") - inputs = {"input": data} + mod, inputs, params = _get_relay_module_and_inputs_from_tflite_file(tflite_model_file) output_list = generate_ref_data(mod, inputs, params) compiled_test_mods = compile_models( @@ -265,3 +281,194 @@ def test_tflite_model(model_url, usmp_algo, workspace_size): runner=test_runner, interface_api=interface_api, ) + + +def _get_workspace_size_define_macro(pool_name: str, model_name="default") -> str: + """This function converts pool names to compiler generated + workspace pool size macros""" + + prefix = "TVMGEN_" + model_name.upper() + "_" + postfix = "_WORKSPACE_POOL_SIZE" + return prefix + pool_name.upper() + postfix + + +@pytest.mark.parametrize( + "model_url, usmp_algo", + [ + (MOBILENET_V1_URL, "greedy_by_size"), + ], +) +def test_tflite_model_u3_usecase_single_external_pool(model_url, usmp_algo): + """This checks for inference with USMP using external pool placed in the application""" + pytest.importorskip("tflite") + + import tvm.relay.testing.tf as tf_testing + + use_unpacked_api = True + interface_api = "c" + + pool_name = "my_memory_pool" + target = tvm.target.Target("c") + workspace_memory_pools = WorkspaceMemoryPools( + [PoolInfo(pool_name, {target: PoolInfo.READ_WRITE_ACCESS})] + ) + test_runner = AOTTestRunner( + pass_config={"tir.usmp.enable": True, "tir.usmp.algorithm": usmp_algo}, + prologue=f""" + __attribute__((section(".data.tvm"), aligned(16))) + static uint8_t {pool_name}[{_get_workspace_size_define_macro(pool_name)}]; + """, + 
) + + tflite_model_file = tf_testing.get_workload_official( + model_url[0], + model_url[1], + ) + mod, inputs, params = _get_relay_module_and_inputs_from_tflite_file(tflite_model_file) + output_list = generate_ref_data(mod, inputs, params) + + compiled_test_mods = compile_models( + AOTTestModel(module=mod, inputs=inputs, outputs=output_list, params=params), + interface_api=interface_api, + use_unpacked_api=use_unpacked_api, + pass_config=test_runner.pass_config, + workspace_memory_pools=workspace_memory_pools, + target=target, + ) + + for compiled_model in compiled_test_mods: + check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib) + + run_and_check( + models=compiled_test_mods, + runner=test_runner, + interface_api=interface_api, + ) + + +@pytest.mark.parametrize( + "model_url, usmp_algo", + [ + (MOBILENET_V1_URL, "greedy_by_size"), + ], +) +def test_tflite_model_u3_usecase_two_external_pools(model_url, usmp_algo): + """This checks for inference using two external pools placed in the application""" + pytest.importorskip("tflite") + + import tvm.relay.testing.tf as tf_testing + + use_unpacked_api = True + interface_api = "c" + + target = tvm.target.Target("c") + workspace_memory_pools = WorkspaceMemoryPools( + [ + PoolInfo( + "my_memory_pool_1", {target: PoolInfo.READ_WRITE_ACCESS}, size_hint_bytes=2500000 + ), + PoolInfo("my_memory_pool_2", {target: PoolInfo.READ_WRITE_ACCESS}), + ] + ) + test_runner = AOTTestRunner( + pass_config={"tir.usmp.enable": True, "tir.usmp.algorithm": usmp_algo}, + prologue=f""" + __attribute__((section(".data.tvm"), aligned(16))) + static uint8_t my_memory_pool_1[{_get_workspace_size_define_macro("my_memory_pool_1")}]; + __attribute__((section(".data.tvm"), aligned(16))) + static uint8_t my_memory_pool_2[{_get_workspace_size_define_macro("my_memory_pool_2")}]; + """, + ) + + tflite_model_file = tf_testing.get_workload_official( + model_url[0], + model_url[1], + ) + mod, inputs, params = 
_get_relay_module_and_inputs_from_tflite_file(tflite_model_file) + output_list = generate_ref_data(mod, inputs, params) + + compiled_test_mods = compile_models( + AOTTestModel(module=mod, inputs=inputs, outputs=output_list, params=params), + interface_api=interface_api, + use_unpacked_api=use_unpacked_api, + pass_config=test_runner.pass_config, + workspace_memory_pools=workspace_memory_pools, + target=target, + ) + + for compiled_model in compiled_test_mods: + check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib) + + run_and_check( + models=compiled_test_mods, + runner=test_runner, + interface_api=interface_api, + ) + + +@pytest.mark.parametrize( + "model_urls, usmp_algo", + [ + ((MOBILENET_V1_URL, MOBILENET_V2_URL), "greedy_by_size"), + ], +) +def test_tflite_model_u2_usecase_two_models_with_a_single_external_pool(model_urls, usmp_algo): + """This checks for inference using a single large enough common pool""" + pytest.importorskip("tflite") + + import tvm.relay.testing.tf as tf_testing + + use_unpacked_api = True + interface_api = "c" + + target = tvm.target.Target("c") + workspace_memory_pools = WorkspaceMemoryPools( + [PoolInfo("my_memory_pool", {target: PoolInfo.READ_WRITE_ACCESS})] + ) + test_runner = AOTTestRunner( + pass_config={"tir.usmp.enable": True, "tir.usmp.algorithm": usmp_algo}, + prologue=f""" + #define MAX(A, B) ((A > B) ? 
A : B) + __attribute__((section(".data.tvm"), aligned(16))) + static uint8_t my_memory_pool[MAX({_get_workspace_size_define_macro("my_memory_pool", "mod1")},{_get_workspace_size_define_macro("my_memory_pool", "mod2")})]; + """, + ) + + tflite_model_file1 = tf_testing.get_workload_official( + model_urls[0][0], + model_urls[0][1], + ) + mod1, inputs1, params1 = _get_relay_module_and_inputs_from_tflite_file(tflite_model_file1) + output_list1 = generate_ref_data(mod1, inputs1, params1) + + tflite_model_file2 = tf_testing.get_workload_official( + model_urls[1][0], + model_urls[1][1], + ) + mod2, inputs2, params2 = _get_relay_module_and_inputs_from_tflite_file(tflite_model_file2) + output_list2 = generate_ref_data(mod2, inputs2, params2) + + compiled_test_mods = compile_models( + [ + AOTTestModel( + name="mod1", module=mod1, inputs=inputs1, outputs=output_list1, params=params1 + ), + AOTTestModel( + name="mod2", module=mod2, inputs=inputs2, outputs=output_list2, params=params2 + ), + ], + interface_api=interface_api, + use_unpacked_api=use_unpacked_api, + pass_config=test_runner.pass_config, + workspace_memory_pools=workspace_memory_pools, + target=target, + ) + + for compiled_model in compiled_test_mods: + check_for_no_tvm_backendallocworkspace_calls(compiled_model.executor_factory.lib) + + run_and_check( + models=compiled_test_mods, + runner=test_runner, + interface_api=interface_api, + )