From d725d49bb0a7f50ccb8d055c0320c2511e4bc457 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Thu, 25 Feb 2021 15:19:41 -0800 Subject: [PATCH 01/13] Add node_storage_scope_ map and field to storage token. --- src/relay/backend/graph_plan_memory.cc | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 15173c2c79db..78898fa2c9be 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -46,6 +46,8 @@ struct StorageToken { int device_type{0}; /*! \brief The storage id */ int64_t storage_id{-1}; + /*! \brief The storage scope */ + std::string storage_scope; }; class StorageAllocaBaseVisitor : public ExprVisitor { @@ -143,8 +145,12 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { void CreateToken(const ExprNode* op, bool can_realloc) final { ICHECK(!token_map_.count(op)); std::vector tokens; + auto expr = GetRef(op); int device_type = - node_device_map_.count(GetRef(op)) ? node_device_map_[GetRef(op)]->value : 0; + node_device_map_.count(expr) ? node_device_map_[expr]->value : 0; + std::string storage_scope = + node_storage_map_.count(expr) ? std::string(node_storage_map_[expr]) : "global"; + if (const auto* tuple_type = op->checked_type().as()) { for (Type t : tuple_type->fields) { const auto* ttype = t.as(); @@ -152,6 +158,7 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { StorageToken* token = arena_->make(); token->ttype = ttype; token->device_type = device_type; + token->storage_scope = storage_scope; tokens.push_back(token); } } else { @@ -160,6 +167,7 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { StorageToken* token = arena_->make(); token->ttype = ttype; token->device_type = device_type; + token->storage_scope = storage_scope; tokens.push_back(token); } token_map_[op] = tokens; @@ -180,6 +188,7 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { // allocator support::Arena* arena_; Map node_device_map_; + Map node_storage_map_; }; class StorageAllocator : public StorageAllocaBaseVisitor { From f3bf4f90a7e8c2ff6f5878b2bf9905f3e91c9915 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Thu, 25 Feb 2021 16:30:56 -0800 Subject: [PATCH 02/13] Use tuple for output memory plan info. --- src/relay/backend/graph_plan_memory.cc | 9 +++++---- src/relay/backend/graph_runtime_codegen.cc | 10 +++++----- 2 files changed, 10 insertions(+), 9 deletions(-) diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 78898fa2c9be..a4dd27514049 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -205,13 +205,13 @@ class StorageAllocator : public StorageAllocaBaseVisitor { } // Run storage allocation for a function. - Map > Plan(const Function& func) { + Map Plan(const Function& func) { prototype_ = StorageAllocaInit(&arena_).GetInitTokenMap(func); this->Run(func); // The value of smap contains two integer arrays where the first array // contains the planned storage ids and the second holds the device types. - Map > smap; + Map smap; int num_annotated_nodes = 0; int num_nodes = 0; @@ -226,7 +226,8 @@ class StorageAllocator : public StorageAllocaBaseVisitor { storage_ids.push_back(tok->storage_id); device_types.push_back(tok->device_type); } - smap.Set(GetRef(kv.first), Array({storage_ids, device_types})); + std::vector fields{IntegerArray{storage_ids}, IntegerArray{device_types}}; + smap.Set(GetRef(kv.first), runtime::ADT::Tuple(fields)); } // Either all or none of the nodes should be annotated. if (num_annotated_nodes != 0 && num_annotated_nodes != num_nodes) { @@ -384,7 +385,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { std::unordered_map > prototype_; }; -Map > GraphPlanMemory(const Function& func) { +Map GraphPlanMemory(const Function& func) { return StorageAllocator().Plan(func); } diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index 7ed150495104..1c785fe74b0e 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -253,13 +253,13 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator storage_info; - for (auto& v : storage_device_info[0]) { + for (auto& v : Downcast(storage_device_info[0])) { storage_info.push_back(v->value); } node->attrs_["storage_id"] = std::move(storage_info); // type std::vector device_types; - for (auto& v : storage_device_info[1]) { + for (auto& v : Downcast(storage_device_info[1])) { device_types.push_back(v->value); } size_t num_unknown_devices = std::count(device_types.begin(), device_types.end(), 0); @@ -320,7 +320,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslatorvalue; + param_storage_ids_[name] = Downcast(storage_device_map_[expr][0])[0]->value; params_[name] = op->data; return to_return; } @@ -382,7 +382,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslatorvalue; + auto call_dev_type = Downcast(device_type)[0]->value; // Normal Relay Function if (targets_.size() == 1) { // homogeneous execution. @@ -548,7 +548,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator params_; std::unordered_map param_storage_ids_; /*! \brief plan memory of device result */ - Map> storage_device_map_; + Map storage_device_map_; /*! \brief lowered funcs */ std::unordered_map lowered_funcs_; /*! \brief name map */ From 3ab43ab7fde407b7307a591bbb2a357a098e761a Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Thu, 25 Feb 2021 21:48:10 -0800 Subject: [PATCH 03/13] Add storage_scope to memory planner output and serialize/deserialize in graph_runtime. --- src/relay/backend/graph_plan_memory.cc | 7 ++++--- src/relay/backend/graph_runtime_codegen.cc | 13 ++++++++++++- src/runtime/graph/graph_runtime.h | 10 ++++++++++ 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index a4dd27514049..ec0d3d681440 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -32,8 +32,6 @@ namespace tvm { namespace relay { -using IntegerArray = Array; - struct StorageToken { /*! \brief Reference counter */ int ref_counter{0}; @@ -218,6 +216,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { for (const auto& kv : token_map_) { std::vector storage_ids; std::vector device_types; + std::vector storage_scopes; for (StorageToken* tok : kv.second) { if (tok->device_type) { num_annotated_nodes++; @@ -225,8 +224,10 @@ class StorageAllocator : public StorageAllocaBaseVisitor { num_nodes++; storage_ids.push_back(tok->storage_id); device_types.push_back(tok->device_type); + storage_scopes.push_back(tok->storage_scope); } - std::vector fields{IntegerArray{storage_ids}, IntegerArray{device_types}}; + std::vector fields{ + Array{storage_ids}, Array{device_types}, Array{storage_scopes}}; smap.Set(GetRef(kv.first), runtime::ADT::Tuple(fields)); } // Either all or none of the nodes should be annotated. diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index 1c785fe74b0e..02e16eb7218a 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -250,13 +250,19 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator storage_info; for (auto& v : Downcast(storage_device_info[0])) { storage_info.push_back(v->value); } node->attrs_["storage_id"] = std::move(storage_info); + // storage scope + std::vector storage_scope; + for (auto& v : Downcast>(storage_device_info[2])) { + storage_scope.push_back(std::string(v)); + } + node->attrs_["storage_scope"] = std::move(storage_scope); // type std::vector device_types; for (auto& v : Downcast(storage_device_info[1])) { @@ -472,12 +478,14 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator storage_ids; + std::vector storage_scopes; std::vector device_types; std::vector dltypes; std::vector node_row_ptr{0}; for (auto node : nodes_) { const auto& shape_vec = dmlc::get(node->attrs_["shape"]); const auto& storage_id = dmlc::get>(node->attrs_["storage_id"]); + const auto& storage_scope = dmlc::get>(node->attrs_["storage_scope"]); const auto& dtype_vec = dmlc::get>(node->attrs_["dtype"]); ICHECK_EQ(node->num_outputs_, shape_vec.size()); @@ -486,6 +494,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslatorattrs_.count("device_index")) { const auto& dev_types = dmlc::get>(node->attrs_["device_index"]); device_types.insert(device_types.end(), dev_types.begin(), dev_types.end()); @@ -501,6 +510,8 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator storage_id; std::vector device_index; std::vector dltype; + std::vector storage_scope; std::vector> shape; // The graph attribute fields. void Load(dmlc::JSONReader* reader) { @@ -299,6 +300,15 @@ class TVM_DLL GraphRuntime : public ModuleNode { reader->Read(&storage_id); ICHECK(!reader->NextArrayItem()); bitmask |= 2; + } else if (key == "storage_scope") { + reader->BeginArray(); + ICHECK(reader->NextArrayItem()); + reader->Read(&type); + ICHECK_EQ(type, "list_str"); + ICHECK(reader->NextArrayItem()); + reader->Read(&storage_scope); + ICHECK(!reader->NextArrayItem()); + bitmask |= 1; } else if (key == "shape") { reader->BeginArray(); ICHECK(reader->NextArrayItem()); From 976c4e5e489eced874d481b0a96126b7b19e8357 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Sat, 27 Feb 2021 22:31:11 -0800 Subject: [PATCH 04/13] Add CollectStorageInfo declaration for use in graph memory planner. --- include/tvm/relay/analysis.h | 9 +++++++++ src/relay/backend/graph_plan_memory.cc | 1 + 2 files changed, 10 insertions(+) diff --git a/include/tvm/relay/analysis.h b/include/tvm/relay/analysis.h index 5dd837038731..904572021deb 100644 --- a/include/tvm/relay/analysis.h +++ b/include/tvm/relay/analysis.h @@ -220,6 +220,15 @@ TVM_DLL tvm::Array AllTypeVars(const Type& t, const IRModule& mod); */ TVM_DLL Map CollectDeviceInfo(const Expr& expr); +/*! + * \brief Collect the device mapping information of each expression. + * + * \param expr The expression. + * + * \return The device mapping. + */ +TVM_DLL Map CollectStorageInfo(const Expr& expr); + /*! * \brief Collect the device anntation operators. * diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index ec0d3d681440..0259a8be2984 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -133,6 +133,7 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { std::unordered_map > GetInitTokenMap( const Function& func) { node_device_map_ = CollectDeviceInfo(func); + node_storage_map_ = CollectStorageInfo(func); this->Run(func); return std::move(token_map_); } From ec46f95f2ad15af982f8e1f53849d571d5684eff Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Mon, 1 Mar 2021 15:13:18 -0800 Subject: [PATCH 05/13] Temporarily hard code intermediate texture storage ids to be persistent and reallocable via pools. This will need to change but serves as a starting point. --- src/relay/backend/graph_plan_memory.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 0259a8be2984..ac50da3d28c5 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -249,7 +249,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { ICHECK(it != prototype_.end()); std::vector tokens; for (StorageToken* tok : it->second) { - if (can_realloc) { + if (can_realloc && tok->storage_scope == "global") { tokens.push_back(Request(tok)); } else { // Allocate a new token, From ed43583ab625f6e4fee53fa54639b79367892cef Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Tue, 2 Mar 2021 14:35:26 -0800 Subject: [PATCH 06/13] Use storage scope attribute when doing storage allocations in GraphRuntime::SetupStorage(). --- src/runtime/graph/graph_runtime.cc | 5 ++++- src/runtime/graph/graph_runtime.h | 1 + 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index 6c51e711aef1..eccb4c9a9dfd 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -279,6 +279,7 @@ void GraphRuntime::SetupStorage() { // Find the maximum space size. for (size_t i = 0; i < attrs_.shape.size(); ++i) { int storage_id = attrs_.storage_id[i]; + std::string storage_scope = attrs_.storage_scope[i]; // Use the fallback device if no device index is available. int device_type = static_cast(ctxs_[0].device_type); if (!attrs_.device_index.empty()) { @@ -315,6 +316,7 @@ void GraphRuntime::SetupStorage() { pool_entry[sid].param_data_entry = i; pool_entry[sid].size = std::max(pool_entry[sid].size, bytes); pool_entry[sid].device_type = device_type; + pool_entry[sid].scope = storage_scope; } // Allocate the space. @@ -330,7 +332,8 @@ void GraphRuntime::SetupStorage() { } else { std::vector shape; shape.push_back(static_cast(pit.size + 3) / 4); - storage_pool_.push_back(NDArray::Empty(shape, DLDataType{kDLFloat, 32, 1}, ctx)); + Optional scope = String(pit.scope); + storage_pool_.push_back(NDArray::Empty(shape, DLDataType{kDLFloat, 32, 1}, ctx, scope)); } } diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h index d78a0ce6d72c..dce2a9bca344 100644 --- a/src/runtime/graph/graph_runtime.h +++ b/src/runtime/graph/graph_runtime.h @@ -184,6 +184,7 @@ class TVM_DLL GraphRuntime : public ModuleNode { int device_type; int param_data_entry; NDArray linked_param; + std::string scope; // PoolEntry(int s, int dev_type, void* pre_linked_param) : // size(s), device_type(dev_type), pre_linked_param(std::move(pre_linked_param)) {} }; From 8011e57ec43f87d29e1176b75904b5aa44791156 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Sun, 7 Mar 2021 15:16:15 -0800 Subject: [PATCH 07/13] Support passing buffer binds to compile engine during lowering. --- python/tvm/driver/build_module.py | 8 +++++++- python/tvm/relay/backend/_backend.py | 13 +++++++++---- src/relay/backend/compile_engine.cc | 23 +++++++++++++++++------ src/relay/backend/compile_engine.h | 4 ++-- 4 files changed, 35 insertions(+), 13 deletions(-) diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 5eaecb422163..90e97dc0ee08 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -56,7 +56,13 @@ def get_binds(args, compact=False, binds=None): arg_list: list The list of symbolic buffers of arguments. """ - binds = {} if binds is None else binds.copy() + + if isinstance(binds, container.Map): + binds = {k : v for (k, v) in binds.items()} + elif isinstance(binds, dict): + binds = binds.copy() + elif binds == None: + binds = {} arg_list = [] for x in args: if isinstance(x, tensor.Tensor): diff --git a/python/tvm/relay/backend/_backend.py b/python/tvm/relay/backend/_backend.py index 65b0c0ba87c7..733d7fcffe80 100644 --- a/python/tvm/relay/backend/_backend.py +++ b/python/tvm/relay/backend/_backend.py @@ -20,7 +20,7 @@ @tvm._ffi.register_func("relay.backend.lower") -def lower(sch, inputs, func_name, source_func): +def lower(sch, inputs, func_name, source_func, binds=None): """Backend function for lowering. Parameters @@ -37,6 +37,11 @@ def lower(sch, inputs, func_name, source_func): source-func : tvm.relay.Function The source function to be lowered. + binds : dict of :any:`Tensor` to :any:`Buffer`, optional + Dictionary that maps the Tensor to Buffer which specified the data layout + requirement of the function. By default, a new compact buffer is created + for each tensor in the argument. + Returns ------- mod : tvm.IRModule @@ -46,7 +51,7 @@ def lower(sch, inputs, func_name, source_func): import traceback try: - f = tvm.driver.lower(sch, inputs, name=func_name) + f = tvm.driver.lower(sch, inputs, name=func_name, binds=binds) # logging.debug("lower function %s", func_name) # logging.debug("%s", _build.lower(sch, inputs, simple_mode=True)) except Exception: @@ -59,7 +64,7 @@ def lower(sch, inputs, func_name, source_func): @tvm._ffi.register_func("relay.backend.build") -def build(mod, target, target_host=None): +def build(mod, target, target_host=None, binds=None): """Backend build function. Parameters @@ -80,7 +85,7 @@ def build(mod, target, target_host=None): """ if target_host == "": target_host = None - return tvm.driver.build(mod, target=target, target_host=target_host) + return tvm.driver.build(mod, target=target, target_host=target_host, binds=binds) @tvm._ffi.register_func("relay._tensor_value_repr") diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index ae975a5f3240..6523c8db6ad3 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -612,11 +612,12 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> class CompileEngineImpl : public CompileEngineNode { public: // Lower the function. - CachedFunc Lower(const CCacheKey& key) { return LowerInternal(key)->cached_func; } + CachedFunc Lower(const CCacheKey& key, const Array& buffers) { + return LowerInternal(key, buffers)->cached_func; } // For now, build one module per function. - PackedFunc JIT(const CCacheKey& key) final { - CCacheValue value = LowerInternal(key); + PackedFunc JIT(const CCacheKey& key, const Array& buffers) final { + CCacheValue value = LowerInternal(key, buffers); if (value->packed_func != nullptr) return value->packed_func; // build the function. tvm::runtime::Module m; @@ -711,7 +712,7 @@ class CompileEngineImpl : public CompileEngineNode { private: // implement lowered func - CCacheValue LowerInternal(const CCacheKey& key) { + CCacheValue LowerInternal(const CCacheKey& key, const Array& buffers = {}) { std::lock_guard lock(mutex_); CCacheValue value; auto it = cache_.find(key); @@ -762,9 +763,19 @@ class CompileEngineImpl : public CompileEngineNode { for (te::Tensor arg : cache_node->outputs) { all_args.push_back(arg); } + + // build the bind map + Map binds; + if (buffers.size() == all_args.size()) { + for (size_t i = 0; i < all_args.size(); i++) { + auto& arg = all_args[i]; + binds.Set(arg, buffers[i]); + } + } + // lower the function if (const auto* f = runtime::Registry::Get("relay.backend.lower")) { - cache_node->funcs = (*f)(cfunc->schedule, all_args, cache_node->func_name, key->source_func); + cache_node->funcs = (*f)(cfunc->schedule, all_args, cache_node->func_name, key->source_func, binds); } else { using tvm::transform::PassContext; With fresh_pass_ctx_scope(PassContext::Create()); @@ -876,7 +887,7 @@ TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear").set_body_typed([](Compi }); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower") - .set_body_typed([](CompileEngine self, CCacheKey key) { return self->Lower(key); }); + .set_body_typed([](CompileEngine self, CCacheKey key, Array buffers) { return self->Lower(key, buffers); }); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc") .set_body_typed([](CompileEngine self, CCacheKey key) { return self->LowerShapeFunc(key); }); diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index d7628e7a5bdf..e58dc499f029 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -201,13 +201,13 @@ class CompileEngineNode : public Object { * \param key The key to the cached function. * \return The result. */ - virtual CachedFunc Lower(const CCacheKey& key) = 0; + virtual CachedFunc Lower(const CCacheKey& key, const Array& buffers = {}) = 0; /*! * \brief Just in time compile to get a PackedFunc. * \param key The key to the cached function. * \return The result. */ - virtual PackedFunc JIT(const CCacheKey& key) = 0; + virtual PackedFunc JIT(const CCacheKey& key, const Array& buffers = {}) = 0; /*! * \brief Lower the shape function. * \param key The key to the cached function. From 30215b916de27177bc31935fa6ca10c9de64a336 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Wed, 17 Mar 2021 13:22:56 -0700 Subject: [PATCH 08/13] Introduce CollectBufferBinds packed func which can provide buffers to be bound during lower. The packed func is specialized over target (and target attr:device) type. --- src/relay/backend/graph_runtime_codegen.cc | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index 02e16eb7218a..f9dc29ab94ab 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -407,8 +407,18 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator buffers; + std::string fbuffer_prefix = "relay.backend." + target->kind->name; + if (Optional t_device = target->GetAttr("device")) { + fbuffer_prefix += ("." + t_device.value()); + } + if (const auto* f = runtime::Registry::Get(fbuffer_prefix + "._CollectBufferBinds")) { + buffers = (*f)(GetRef(op), storage_device_map_); + } + CCacheKey key = (*pf0)(func, target); - CachedFunc lowered_func = (*pf1)(compile_engine_, key); + CachedFunc lowered_func = (*pf1)(compile_engine_, key, buffers); if (!lowered_funcs_.count(target->str())) { lowered_funcs_[target->str()] = IRModule(Map({})); } From 130cf5f7ab235b2a568a500628661e1c172b474a Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Wed, 17 Mar 2021 13:27:17 -0700 Subject: [PATCH 09/13] Properly handle storage scope for multiple output nodes. --- include/tvm/relay/analysis.h | 4 ++-- src/relay/backend/graph_plan_memory.cc | 24 ++++++++++++++++-------- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/include/tvm/relay/analysis.h b/include/tvm/relay/analysis.h index 904572021deb..a78c9297ec9c 100644 --- a/include/tvm/relay/analysis.h +++ b/include/tvm/relay/analysis.h @@ -221,13 +221,13 @@ TVM_DLL tvm::Array AllTypeVars(const Type& t, const IRModule& mod); TVM_DLL Map CollectDeviceInfo(const Expr& expr); /*! - * \brief Collect the device mapping information of each expression. + * \brief Collect the output storage information of each expression. * * \param expr The expression. * * \return The device mapping. */ -TVM_DLL Map CollectStorageInfo(const Expr& expr); +Map> CollectStorageInfo(const Expr& expr); /*! * \brief Collect the device anntation operators. diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index ac50da3d28c5..36a2f62777cf 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -45,7 +45,7 @@ struct StorageToken { /*! \brief The storage id */ int64_t storage_id{-1}; /*! \brief The storage scope */ - std::string storage_scope; + std::string storage_scope{"global"}; }; class StorageAllocaBaseVisitor : public ExprVisitor { @@ -147,17 +147,23 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { auto expr = GetRef(op); int device_type = node_device_map_.count(expr) ? node_device_map_[expr]->value : 0; - std::string storage_scope = - node_storage_map_.count(expr) ? std::string(node_storage_map_[expr]) : "global"; + + Optional> storage_info; + if (node_storage_map_.count(GetRef(op))) { + storage_info = node_storage_map_[GetRef(op)]; + } if (const auto* tuple_type = op->checked_type().as()) { - for (Type t : tuple_type->fields) { - const auto* ttype = t.as(); + if (storage_info.defined()) { ICHECK_EQ(tuple_type->fields.size(), storage_info.value().size()); } + for (size_t i = 0; i < tuple_type->fields.size(); i++) { + const auto* ttype = tuple_type->fields[i].as(); ICHECK(ttype); StorageToken* token = arena_->make(); token->ttype = ttype; token->device_type = device_type; - token->storage_scope = storage_scope; + if (storage_info.defined()) { + token->storage_scope = storage_info.value()[i]; + } tokens.push_back(token); } } else { @@ -166,7 +172,9 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { StorageToken* token = arena_->make(); token->ttype = ttype; token->device_type = device_type; - token->storage_scope = storage_scope; + if (storage_info.defined()) { + token->storage_scope = storage_info.value()[0]; + } tokens.push_back(token); } token_map_[op] = tokens; @@ -187,7 +195,7 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { // allocator support::Arena* arena_; Map node_device_map_; - Map node_storage_map_; + Map> node_storage_map_; }; class StorageAllocator : public StorageAllocaBaseVisitor { From 9bc10af4193f07f001fbe7d2b9489b18b521c000 Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Fri, 12 Mar 2021 13:58:36 -0800 Subject: [PATCH 10/13] Add bound buffers to compile engine cache key. --- src/relay/backend/compile_engine.cc | 7 ++++--- src/relay/backend/compile_engine.h | 9 ++++++++- src/relay/backend/graph_runtime_codegen.cc | 2 +- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 6523c8db6ad3..7a9da640e6a1 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -64,10 +64,11 @@ LoweredOutput::LoweredOutput(tvm::Array outputs, OpImplementation im data_ = std::move(n); } -CCacheKey::CCacheKey(Function source_func, Target target) { +CCacheKey::CCacheKey(Function source_func, Target target, Array buffers) { auto n = make_object(); n->source_func = std::move(source_func); n->target = std::move(target); + n->buffers = std::move(buffers); data_ = std::move(n); } @@ -874,8 +875,8 @@ TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput") }); TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey") - .set_body_typed([](Function source_func, Target target) { - return CCacheKey(source_func, target); + .set_body_typed([](Function source_func, Target target, Array buffers = {}) { + return CCacheKey(source_func, target, buffers); }); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal").set_body_typed([]() { diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index e58dc499f029..184fa6c3b542 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -114,6 +114,8 @@ class CCacheKeyNode : public Object { Function source_func; /*! \brief The hardware target.*/ Target target; + /*! \brief Any buffers bound to the source function. */ + Array buffers; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("source_func", &source_func); @@ -148,8 +150,9 @@ class CCacheKey : public ObjectRef { * \brief The constructor * \param source_func The source function. * \param target The target device. + * \param buffers Optional bound buffers */ - TVM_DLL CCacheKey(Function source_func, Target target); + TVM_DLL CCacheKey(Function source_func, Target target, Array buffers = {}); const CCacheKeyNode* operator->() const { return static_cast(get()); } // comparator @@ -269,6 +272,10 @@ inline size_t CCacheKeyNode::Hash() const { inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const { if (Hash() != other->Hash()) return false; + if (other->buffers.size() != this->buffers.size()) return false; + for (size_t i = 0; i < other->buffers.size(); i++) { + if (!tvm::StructuralEqual()(other->buffers[i], this->buffers[i])) return false; + } return this->target->str() == other->target->str() && tvm::StructuralEqual()(this->source_func, other->source_func); } diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index f9dc29ab94ab..31580c4bc076 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -417,7 +417,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator(op), storage_device_map_); } - CCacheKey key = (*pf0)(func, target); + CCacheKey key = (*pf0)(func, target, buffers); CachedFunc lowered_func = (*pf1)(compile_engine_, key, buffers); if (!lowered_funcs_.count(target->str())) { lowered_funcs_[target->str()] = IRModule(Map({})); From 5aaf9db4c5e8fd6021aa33d81f8c5ec697d7191c Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Tue, 16 Mar 2021 13:35:13 -0700 Subject: [PATCH 11/13] Change TargetsMap in GRC to from std::unordered_map to tvm::Map. --- src/relay/backend/graph_runtime_codegen.cc | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index 31580c4bc076..fd6ff08cac43 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -49,7 +49,7 @@ using GraphAttrs = std::unordered_map; using GraphObjectPtr = std::shared_ptr; using GraphInputObjectPtr = std::shared_ptr; using GraphOpObjectPtr = std::shared_ptr; -using TargetsMap = std::unordered_map; +using TargetsMap = Map; /*! \brief Lowered outputs */ struct LoweredOutput { @@ -585,15 +585,9 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode { if (name == "init") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { ICHECK_EQ(args.num_args, 2) << "The expected of arguments are: " - << "runtime::Module mod and Map targets"; + << "runtime::Module mod and Map targets"; void* mod = args[0]; - Map tmp = args[1]; - TargetsMap targets; - for (const auto& it : tmp) { - auto dev_type = it.first.as(); - ICHECK(dev_type); - targets[dev_type->value] = it.second; - } + TargetsMap targets = args[1]; codegen_ = std::make_shared(reinterpret_cast(mod), targets); }); From a9e62e556f50b8e3c6c512c48060eb328df93d2b Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Wed, 17 Mar 2021 13:31:31 -0700 Subject: [PATCH 12/13] Dispatch to target specific storage scope collection function (CollectStorageInfo) in memory planning. This allows users to implement a target specific storage mapping as part of memory planning that depends on the target specific constraints. --- include/tvm/relay/analysis.h | 9 ---- src/relay/backend/graph_plan_memory.cc | 50 +++++++++++++++++++--- src/relay/backend/graph_runtime_codegen.cc | 2 +- 3 files changed, 45 insertions(+), 16 deletions(-) diff --git a/include/tvm/relay/analysis.h b/include/tvm/relay/analysis.h index a78c9297ec9c..5dd837038731 100644 --- a/include/tvm/relay/analysis.h +++ b/include/tvm/relay/analysis.h @@ -220,15 +220,6 @@ TVM_DLL tvm::Array AllTypeVars(const Type& t, const IRModule& mod); */ TVM_DLL Map CollectDeviceInfo(const Expr& expr); -/*! - * \brief Collect the output storage information of each expression. - * - * \param expr The expression. - * - * \return The device mapping. - */ -Map> CollectStorageInfo(const Expr& expr); - /*! * \brief Collect the device anntation operators. * diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 36a2f62777cf..0751c3ea2ff3 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -26,12 +26,17 @@ #include #include #include +#include #include "../../support/arena.h" namespace tvm { namespace relay { +using TargetsMap = Map; +using Texture2DShape = runtime::Texture2DShape; +constexpr auto Is2DStorage = runtime::IsTextureStorage; + struct StorageToken { /*! \brief Reference counter */ int ref_counter{0}; @@ -125,15 +130,48 @@ class StorageAllocaBaseVisitor : public ExprVisitor { virtual void CreateToken(const ExprNode* op, bool can_realloc) = 0; }; +/*! + * \brief Collect the target specific tensor storage info for each expression's output. + * \param expr The expression. + * \param expr The device id map which can be used to infer device specific storage scope availability. + * \param expr The target mapping from device id to target. + * \return The device based storage mapping. + */ +Map> CollectStorageInfo(const Expr& expr, const Map& dev_map, const TargetsMap& target_map) { + auto less = [](Integer i, Integer j) { + auto i_imm = i.as(); + auto j_imm = j.as(); + ICHECK(i_imm && j_imm); + return i_imm->value < j_imm->value; + }; + std::set device_types(less); + for (auto& kv : target_map) { + device_types.insert(kv.first); + } + std::string ftarget_prefix = "relay.backend"; + for (auto& dev_id : device_types) { + Target target = target_map[dev_id]; + ftarget_prefix += ("." + target->kind->name); + if (Optional t_device = target->GetAttr("device")) { + ftarget_prefix += ("." + t_device.value()); + } + } + Map> storage_info = {}; + if (const auto* f = runtime::Registry::Get(ftarget_prefix + "._CollectStorageInfo")) { + storage_info = (*f)(expr, dev_map, target_map); + } + return storage_info; +} + class StorageAllocaInit : protected StorageAllocaBaseVisitor { public: explicit StorageAllocaInit(support::Arena* arena) : arena_(arena) {} /*! \return The internal token map */ std::unordered_map > GetInitTokenMap( - const Function& func) { + const Function& func, const TargetsMap& targets) { node_device_map_ = CollectDeviceInfo(func); - node_storage_map_ = CollectStorageInfo(func); + node_storage_map_ = CollectStorageInfo(func, node_device_map_, targets); this->Run(func); return std::move(token_map_); } @@ -212,8 +250,8 @@ class StorageAllocator : public StorageAllocaBaseVisitor { } // Run storage allocation for a function. - Map Plan(const Function& func) { - prototype_ = StorageAllocaInit(&arena_).GetInitTokenMap(func); + Map Plan(const Function& func, const TargetsMap& targets) { + prototype_ = StorageAllocaInit(&arena_).GetInitTokenMap(func, targets); this->Run(func); // The value of smap contains two integer arrays where the first array @@ -395,8 +433,8 @@ class StorageAllocator : public StorageAllocaBaseVisitor { std::unordered_map > prototype_; }; -Map GraphPlanMemory(const Function& func) { - return StorageAllocator().Plan(func); +Map GraphPlanMemory(const Function& func, const TargetsMap& targets) { + return StorageAllocator().Plan(func, targets); } TVM_REGISTER_GLOBAL("relay.backend.GraphPlanMemory").set_body_typed(GraphPlanMemory); diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index fd6ff08cac43..51a4f14b21c2 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -191,7 +191,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslatorparams) { auto node_ptr = GraphInputNode::make_node_ptr(param->name_hint(), GraphAttrs()); From 53cfedbc98245aff57ce964a801aa3b2f8e10eca Mon Sep 17 00:00:00 2001 From: Chris Sullivan Date: Tue, 16 Mar 2021 15:26:02 -0700 Subject: [PATCH 13/13] Rename target prefix --- src/relay/backend/graph_runtime_codegen.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index 51a4f14b21c2..dd6f74247531 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -409,11 +409,11 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator buffers; - std::string fbuffer_prefix = "relay.backend." + target->kind->name; + std::string ftarget_prefix = "relay.backend." + target->kind->name; if (Optional t_device = target->GetAttr("device")) { - fbuffer_prefix += ("." + t_device.value()); + ftarget_prefix += ("." + t_device.value()); } - if (const auto* f = runtime::Registry::Get(fbuffer_prefix + "._CollectBufferBinds")) { + if (const auto* f = runtime::Registry::Get(ftarget_prefix + "._CollectBufferBinds")) { buffers = (*f)(GetRef(op), storage_device_map_); }