diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py index 5eaecb422163..90e97dc0ee08 100644 --- a/python/tvm/driver/build_module.py +++ b/python/tvm/driver/build_module.py @@ -56,7 +56,13 @@ def get_binds(args, compact=False, binds=None): arg_list: list The list of symbolic buffers of arguments. """ - binds = {} if binds is None else binds.copy() + + if isinstance(binds, container.Map): + binds = {k: v for (k, v) in binds.items()} + elif isinstance(binds, dict): + binds = binds.copy() + elif binds is None: + binds = {} arg_list = [] for x in args: if isinstance(x, tensor.Tensor): diff --git a/python/tvm/relay/backend/_backend.py b/python/tvm/relay/backend/_backend.py index 65b0c0ba87c7..733d7fcffe80 100644 --- a/python/tvm/relay/backend/_backend.py +++ b/python/tvm/relay/backend/_backend.py @@ -20,7 +20,7 @@ @tvm._ffi.register_func("relay.backend.lower") -def lower(sch, inputs, func_name, source_func): +def lower(sch, inputs, func_name, source_func, binds=None): """Backend function for lowering. Parameters @@ -37,6 +37,11 @@ def lower(sch, inputs, func_name, source_func): source-func : tvm.relay.Function The source function to be lowered. + binds : dict of :any:`Tensor` to :any:`Buffer`, optional + Dictionary that maps the Tensor to Buffer which specifies the data layout + requirement of the function. By default, a new compact buffer is created + for each tensor in the argument. 
+ Returns ------- mod : tvm.IRModule @@ -46,7 +51,7 @@ def lower(sch, inputs, func_name, source_func): import traceback try: - f = tvm.driver.lower(sch, inputs, name=func_name) + f = tvm.driver.lower(sch, inputs, name=func_name, binds=binds) # logging.debug("lower function %s", func_name) # logging.debug("%s", _build.lower(sch, inputs, simple_mode=True)) except Exception: @@ -59,7 +64,7 @@ def lower(sch, inputs, func_name, source_func): @tvm._ffi.register_func("relay.backend.build") -def build(mod, target, target_host=None): +def build(mod, target, target_host=None, binds=None): """Backend build function. Parameters @@ -80,7 +85,7 @@ def build(mod, target, target_host=None): """ if target_host == "": target_host = None - return tvm.driver.build(mod, target=target, target_host=target_host) + return tvm.driver.build(mod, target=target, target_host=target_host, binds=binds) @tvm._ffi.register_func("relay._tensor_value_repr") diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index ae975a5f3240..7a9da640e6a1 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -64,10 +64,11 @@ LoweredOutput::LoweredOutput(tvm::Array outputs, OpImplementation im data_ = std::move(n); } -CCacheKey::CCacheKey(Function source_func, Target target) { +CCacheKey::CCacheKey(Function source_func, Target target, Array buffers) { auto n = make_object(); n->source_func = std::move(source_func); n->target = std::move(target); + n->buffers = std::move(buffers); data_ = std::move(n); } @@ -612,11 +613,12 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator> class CompileEngineImpl : public CompileEngineNode { public: // Lower the function. - CachedFunc Lower(const CCacheKey& key) { return LowerInternal(key)->cached_func; } + CachedFunc Lower(const CCacheKey& key, const Array& buffers) { + return LowerInternal(key, buffers)->cached_func; } // For now, build one module per function. 
- PackedFunc JIT(const CCacheKey& key) final { - CCacheValue value = LowerInternal(key); + PackedFunc JIT(const CCacheKey& key, const Array& buffers) final { + CCacheValue value = LowerInternal(key, buffers); if (value->packed_func != nullptr) return value->packed_func; // build the function. tvm::runtime::Module m; @@ -711,7 +713,7 @@ class CompileEngineImpl : public CompileEngineNode { private: // implement lowered func - CCacheValue LowerInternal(const CCacheKey& key) { + CCacheValue LowerInternal(const CCacheKey& key, const Array& buffers = {}) { std::lock_guard lock(mutex_); CCacheValue value; auto it = cache_.find(key); @@ -762,9 +764,19 @@ class CompileEngineImpl : public CompileEngineNode { for (te::Tensor arg : cache_node->outputs) { all_args.push_back(arg); } + + // build the bind map + Map binds; + if (buffers.size() == all_args.size()) { + for (size_t i = 0; i < all_args.size(); i++) { + auto& arg = all_args[i]; + binds.Set(arg, buffers[i]); + } + } + // lower the function if (const auto* f = runtime::Registry::Get("relay.backend.lower")) { - cache_node->funcs = (*f)(cfunc->schedule, all_args, cache_node->func_name, key->source_func); + cache_node->funcs = (*f)(cfunc->schedule, all_args, cache_node->func_name, key->source_func, binds); } else { using tvm::transform::PassContext; With fresh_pass_ctx_scope(PassContext::Create()); @@ -863,8 +875,8 @@ TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput") }); TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey") - .set_body_typed([](Function source_func, Target target) { - return CCacheKey(source_func, target); + .set_body_typed([](Function source_func, Target target, Array buffers = {}) { + return CCacheKey(source_func, target, buffers); }); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal").set_body_typed([]() { @@ -876,7 +888,7 @@ TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear").set_body_typed([](Compi }); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower") - 
.set_body_typed([](CompileEngine self, CCacheKey key) { return self->Lower(key); }); + .set_body_typed([](CompileEngine self, CCacheKey key, Array buffers) { return self->Lower(key, buffers); }); TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc") .set_body_typed([](CompileEngine self, CCacheKey key) { return self->LowerShapeFunc(key); }); diff --git a/src/relay/backend/compile_engine.h b/src/relay/backend/compile_engine.h index d7628e7a5bdf..184fa6c3b542 100644 --- a/src/relay/backend/compile_engine.h +++ b/src/relay/backend/compile_engine.h @@ -114,6 +114,8 @@ class CCacheKeyNode : public Object { Function source_func; /*! \brief The hardware target.*/ Target target; + /*! \brief Any buffers bound to the source function. */ + Array buffers; void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("source_func", &source_func); @@ -148,8 +150,9 @@ class CCacheKey : public ObjectRef { * \brief The constructor * \param source_func The source function. * \param target The target device. + * \param buffers Optional bound buffers */ - TVM_DLL CCacheKey(Function source_func, Target target); + TVM_DLL CCacheKey(Function source_func, Target target, Array buffers = {}); const CCacheKeyNode* operator->() const { return static_cast(get()); } // comparator @@ -201,13 +204,13 @@ class CompileEngineNode : public Object { * \param key The key to the cached function. * \return The result. */ - virtual CachedFunc Lower(const CCacheKey& key) = 0; + virtual CachedFunc Lower(const CCacheKey& key, const Array& buffers = {}) = 0; /*! * \brief Just in time compile to get a PackedFunc. * \param key The key to the cached function. * \return The result. */ - virtual PackedFunc JIT(const CCacheKey& key) = 0; + virtual PackedFunc JIT(const CCacheKey& key, const Array& buffers = {}) = 0; /*! * \brief Lower the shape function. * \param key The key to the cached function. 
@@ -269,6 +272,10 @@ inline size_t CCacheKeyNode::Hash() const { inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const { if (Hash() != other->Hash()) return false; + if (other->buffers.size() != this->buffers.size()) return false; + for (size_t i = 0; i < other->buffers.size(); i++) { + if (!tvm::StructuralEqual()(other->buffers[i], this->buffers[i])) return false; + } return this->target->str() == other->target->str() && tvm::StructuralEqual()(this->source_func, other->source_func); } diff --git a/src/relay/backend/graph_plan_memory.cc b/src/relay/backend/graph_plan_memory.cc index 15173c2c79db..0751c3ea2ff3 100644 --- a/src/relay/backend/graph_plan_memory.cc +++ b/src/relay/backend/graph_plan_memory.cc @@ -26,13 +26,16 @@ #include #include #include +#include #include "../../support/arena.h" namespace tvm { namespace relay { -using IntegerArray = Array; +using TargetsMap = Map; +using Texture2DShape = runtime::Texture2DShape; +constexpr auto Is2DStorage = runtime::IsTextureStorage; struct StorageToken { /*! \brief Reference counter */ @@ -46,6 +49,8 @@ struct StorageToken { int device_type{0}; /*! \brief The storage id */ int64_t storage_id{-1}; + /*! \brief The storage scope */ + std::string storage_scope{"global"}; }; class StorageAllocaBaseVisitor : public ExprVisitor { @@ -125,14 +130,48 @@ class StorageAllocaBaseVisitor : public ExprVisitor { virtual void CreateToken(const ExprNode* op, bool can_realloc) = 0; }; +/*! + * \brief Collect the target specific tensor storage info for each expression's output. + * \param expr The expression. + * \param expr The device id map which can be used to infer device specific storage scope availability. + * \param expr The target mapping from device id to target. + * \return The device based storage mapping. 
+ */ +Map> CollectStorageInfo(const Expr& expr, const Map& dev_map, const TargetsMap& target_map) { + auto less = [](Integer i, Integer j) { + auto i_imm = i.as(); + auto j_imm = j.as(); + ICHECK(i_imm && j_imm); + return i_imm->value < j_imm->value; + }; + std::set device_types(less); + for (auto& kv : target_map) { + device_types.insert(kv.first); + } + std::string ftarget_prefix = "relay.backend"; + for (auto& dev_id : device_types) { + Target target = target_map[dev_id]; + ftarget_prefix += ("." + target->kind->name); + if (Optional t_device = target->GetAttr("device")) { + ftarget_prefix += ("." + t_device.value()); + } + } + Map> storage_info = {}; + if (const auto* f = runtime::Registry::Get(ftarget_prefix + "._CollectStorageInfo")) { + storage_info = (*f)(expr, dev_map, target_map); + } + return storage_info; +} + class StorageAllocaInit : protected StorageAllocaBaseVisitor { public: explicit StorageAllocaInit(support::Arena* arena) : arena_(arena) {} /*! \return The internal token map */ std::unordered_map > GetInitTokenMap( - const Function& func) { + const Function& func, const TargetsMap& targets) { node_device_map_ = CollectDeviceInfo(func); + node_storage_map_ = CollectStorageInfo(func, node_device_map_, targets); this->Run(func); return std::move(token_map_); } @@ -143,15 +182,26 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { void CreateToken(const ExprNode* op, bool can_realloc) final { ICHECK(!token_map_.count(op)); std::vector tokens; + auto expr = GetRef(op); int device_type = - node_device_map_.count(GetRef(op)) ? node_device_map_[GetRef(op)]->value : 0; + node_device_map_.count(expr) ? 
node_device_map_[expr]->value : 0; + + Optional> storage_info; + if (node_storage_map_.count(GetRef(op))) { + storage_info = node_storage_map_[GetRef(op)]; + } + if (const auto* tuple_type = op->checked_type().as()) { - for (Type t : tuple_type->fields) { - const auto* ttype = t.as(); + if (storage_info.defined()) { ICHECK_EQ(tuple_type->fields.size(), storage_info.value().size()); } + for (size_t i = 0; i < tuple_type->fields.size(); i++) { + const auto* ttype = tuple_type->fields[i].as(); ICHECK(ttype); StorageToken* token = arena_->make(); token->ttype = ttype; token->device_type = device_type; + if (storage_info.defined()) { + token->storage_scope = storage_info.value()[i]; + } tokens.push_back(token); } } else { @@ -160,6 +210,9 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { StorageToken* token = arena_->make(); token->ttype = ttype; token->device_type = device_type; + if (storage_info.defined()) { + token->storage_scope = storage_info.value()[0]; + } tokens.push_back(token); } token_map_[op] = tokens; @@ -180,6 +233,7 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor { // allocator support::Arena* arena_; Map node_device_map_; + Map> node_storage_map_; }; class StorageAllocator : public StorageAllocaBaseVisitor { @@ -196,19 +250,20 @@ class StorageAllocator : public StorageAllocaBaseVisitor { } // Run storage allocation for a function. - Map > Plan(const Function& func) { - prototype_ = StorageAllocaInit(&arena_).GetInitTokenMap(func); + Map Plan(const Function& func, const TargetsMap& targets) { + prototype_ = StorageAllocaInit(&arena_).GetInitTokenMap(func, targets); this->Run(func); // The value of smap contains two integer arrays where the first array // contains the planned storage ids and the second holds the device types. 
- Map > smap; + Map smap; int num_annotated_nodes = 0; int num_nodes = 0; for (const auto& kv : token_map_) { std::vector storage_ids; std::vector device_types; + std::vector storage_scopes; for (StorageToken* tok : kv.second) { if (tok->device_type) { num_annotated_nodes++; @@ -216,8 +271,11 @@ class StorageAllocator : public StorageAllocaBaseVisitor { num_nodes++; storage_ids.push_back(tok->storage_id); device_types.push_back(tok->device_type); + storage_scopes.push_back(tok->storage_scope); } - smap.Set(GetRef(kv.first), Array({storage_ids, device_types})); + std::vector fields{ + Array{storage_ids}, Array{device_types}, Array{storage_scopes}}; + smap.Set(GetRef(kv.first), runtime::ADT::Tuple(fields)); } // Either all or none of the nodes should be annotated. if (num_annotated_nodes != 0 && num_annotated_nodes != num_nodes) { @@ -237,7 +295,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor { ICHECK(it != prototype_.end()); std::vector tokens; for (StorageToken* tok : it->second) { - if (can_realloc) { + if (can_realloc && tok->storage_scope == "global") { tokens.push_back(Request(tok)); } else { // Allocate a new token, @@ -375,8 +433,8 @@ class StorageAllocator : public StorageAllocaBaseVisitor { std::unordered_map > prototype_; }; -Map > GraphPlanMemory(const Function& func) { - return StorageAllocator().Plan(func); +Map GraphPlanMemory(const Function& func, const TargetsMap& targets) { + return StorageAllocator().Plan(func, targets); } TVM_REGISTER_GLOBAL("relay.backend.GraphPlanMemory").set_body_typed(GraphPlanMemory); diff --git a/src/relay/backend/graph_runtime_codegen.cc b/src/relay/backend/graph_runtime_codegen.cc index 7ed150495104..dd6f74247531 100644 --- a/src/relay/backend/graph_runtime_codegen.cc +++ b/src/relay/backend/graph_runtime_codegen.cc @@ -49,7 +49,7 @@ using GraphAttrs = std::unordered_map; using GraphObjectPtr = std::shared_ptr; using GraphInputObjectPtr = std::shared_ptr; using GraphOpObjectPtr = std::shared_ptr; -using 
TargetsMap = std::unordered_map; +using TargetsMap = Map; /*! \brief Lowered outputs */ struct LoweredOutput { @@ -191,7 +191,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslatorparams) { auto node_ptr = GraphInputNode::make_node_ptr(param->name_hint(), GraphAttrs()); @@ -250,16 +250,22 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator storage_info; - for (auto& v : storage_device_info[0]) { + for (auto& v : Downcast(storage_device_info[0])) { storage_info.push_back(v->value); } node->attrs_["storage_id"] = std::move(storage_info); + // storage scope + std::vector storage_scope; + for (auto& v : Downcast>(storage_device_info[2])) { + storage_scope.push_back(std::string(v)); + } + node->attrs_["storage_scope"] = std::move(storage_scope); // type std::vector device_types; - for (auto& v : storage_device_info[1]) { + for (auto& v : Downcast(storage_device_info[1])) { device_types.push_back(v->value); } size_t num_unknown_devices = std::count(device_types.begin(), device_types.end(), 0); @@ -320,7 +326,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslatorvalue; + param_storage_ids_[name] = Downcast(storage_device_map_[expr][0])[0]->value; params_[name] = op->data; return to_return; } @@ -382,7 +388,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslatorvalue; + auto call_dev_type = Downcast(device_type)[0]->value; // Normal Relay Function if (targets_.size() == 1) { // homogeneous execution. @@ -401,8 +407,18 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator buffers; + std::string ftarget_prefix = "relay.backend." + target->kind->name; + if (Optional t_device = target->GetAttr("device")) { + ftarget_prefix += ("." 
+ t_device.value()); + } + if (const auto* f = runtime::Registry::Get(ftarget_prefix + "._CollectBufferBinds")) { + buffers = (*f)(GetRef(op), storage_device_map_); + } + + CCacheKey key = (*pf0)(func, target, buffers); + CachedFunc lowered_func = (*pf1)(compile_engine_, key, buffers); if (!lowered_funcs_.count(target->str())) { lowered_funcs_[target->str()] = IRModule(Map({})); } @@ -472,12 +488,14 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator storage_ids; + std::vector storage_scopes; std::vector device_types; std::vector dltypes; std::vector node_row_ptr{0}; for (auto node : nodes_) { const auto& shape_vec = dmlc::get(node->attrs_["shape"]); const auto& storage_id = dmlc::get>(node->attrs_["storage_id"]); + const auto& storage_scope = dmlc::get>(node->attrs_["storage_scope"]); const auto& dtype_vec = dmlc::get>(node->attrs_["dtype"]); ICHECK_EQ(node->num_outputs_, shape_vec.size()); @@ -486,6 +504,7 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslatorattrs_.count("device_index")) { const auto& dev_types = dmlc::get>(node->attrs_["device_index"]); device_types.insert(device_types.end(), dev_types.begin(), dev_types.end()); @@ -501,6 +520,8 @@ class GraphRuntimeCodegen : public backend::MemoizedExprTranslator params_; std::unordered_map param_storage_ids_; /*! \brief plan memory of device result */ - Map> storage_device_map_; + Map storage_device_map_; /*! \brief lowered funcs */ std::unordered_map lowered_funcs_; /*! 
\brief name map */ @@ -564,15 +585,9 @@ class GraphRuntimeCodegenModule : public runtime::ModuleNode { if (name == "init") { return PackedFunc([sptr_to_self, this](TVMArgs args, TVMRetValue* rv) { ICHECK_EQ(args.num_args, 2) << "The expected of arguments are: " - << "runtime::Module mod and Map targets"; + << "runtime::Module mod and Map targets"; void* mod = args[0]; - Map tmp = args[1]; - TargetsMap targets; - for (const auto& it : tmp) { - auto dev_type = it.first.as(); - ICHECK(dev_type); - targets[dev_type->value] = it.second; - } + TargetsMap targets = args[1]; codegen_ = std::make_shared(reinterpret_cast(mod), targets); }); diff --git a/src/runtime/graph/graph_runtime.cc b/src/runtime/graph/graph_runtime.cc index 6c51e711aef1..eccb4c9a9dfd 100644 --- a/src/runtime/graph/graph_runtime.cc +++ b/src/runtime/graph/graph_runtime.cc @@ -279,6 +279,7 @@ void GraphRuntime::SetupStorage() { // Find the maximum space size. for (size_t i = 0; i < attrs_.shape.size(); ++i) { int storage_id = attrs_.storage_id[i]; + std::string storage_scope = attrs_.storage_scope[i]; // Use the fallback device if no device index is available. int device_type = static_cast(ctxs_[0].device_type); if (!attrs_.device_index.empty()) { @@ -315,6 +316,7 @@ void GraphRuntime::SetupStorage() { pool_entry[sid].param_data_entry = i; pool_entry[sid].size = std::max(pool_entry[sid].size, bytes); pool_entry[sid].device_type = device_type; + pool_entry[sid].scope = storage_scope; } // Allocate the space. 
@@ -330,7 +332,8 @@ void GraphRuntime::SetupStorage() { } else { std::vector shape; shape.push_back(static_cast(pit.size + 3) / 4); - storage_pool_.push_back(NDArray::Empty(shape, DLDataType{kDLFloat, 32, 1}, ctx)); + Optional scope = String(pit.scope); + storage_pool_.push_back(NDArray::Empty(shape, DLDataType{kDLFloat, 32, 1}, ctx, scope)); } } diff --git a/src/runtime/graph/graph_runtime.h b/src/runtime/graph/graph_runtime.h index a1e2ee3b5d74..dce2a9bca344 100644 --- a/src/runtime/graph/graph_runtime.h +++ b/src/runtime/graph/graph_runtime.h @@ -184,6 +184,7 @@ class TVM_DLL GraphRuntime : public ModuleNode { int device_type; int param_data_entry; NDArray linked_param; + std::string scope; // PoolEntry(int s, int dev_type, void* pre_linked_param) : // size(s), device_type(dev_type), pre_linked_param(std::move(pre_linked_param)) {} }; @@ -274,6 +275,7 @@ class TVM_DLL GraphRuntime : public ModuleNode { std::vector storage_id; std::vector device_index; std::vector dltype; + std::vector storage_scope; std::vector> shape; // The graph attribute fields. void Load(dmlc::JSONReader* reader) { @@ -299,6 +301,15 @@ class TVM_DLL GraphRuntime : public ModuleNode { reader->Read(&storage_id); ICHECK(!reader->NextArrayItem()); bitmask |= 2; + } else if (key == "storage_scope") { + reader->BeginArray(); + ICHECK(reader->NextArrayItem()); + reader->Read(&type); + ICHECK_EQ(type, "list_str"); + ICHECK(reader->NextArrayItem()); + reader->Read(&storage_scope); + ICHECK(!reader->NextArrayItem()); + bitmask |= 1; } else if (key == "shape") { reader->BeginArray(); ICHECK(reader->NextArrayItem());