Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,13 @@ def get_binds(args, compact=False, binds=None):
arg_list: list
The list of symbolic buffers of arguments.
"""
binds = {} if binds is None else binds.copy()

if isinstance(binds, container.Map):
binds = {k : v for (k, v) in binds.items()}
elif isinstance(binds, dict):
binds = binds.copy()
elif binds == None:
binds = {}
arg_list = []
for x in args:
if isinstance(x, tensor.Tensor):
Expand Down
13 changes: 9 additions & 4 deletions python/tvm/relay/backend/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@


@tvm._ffi.register_func("relay.backend.lower")
def lower(sch, inputs, func_name, source_func):
def lower(sch, inputs, func_name, source_func, binds=None):
"""Backend function for lowering.

Parameters
Expand All @@ -37,6 +37,11 @@ def lower(sch, inputs, func_name, source_func):
source-func : tvm.relay.Function
The source function to be lowered.

binds : dict of :any:`Tensor` to :any:`Buffer`, optional
        Dictionary that maps the Tensor to Buffer which specifies the data layout
requirement of the function. By default, a new compact buffer is created
for each tensor in the argument.

Returns
-------
mod : tvm.IRModule
Expand All @@ -46,7 +51,7 @@ def lower(sch, inputs, func_name, source_func):
import traceback

try:
f = tvm.driver.lower(sch, inputs, name=func_name)
f = tvm.driver.lower(sch, inputs, name=func_name, binds=binds)
# logging.debug("lower function %s", func_name)
# logging.debug("%s", _build.lower(sch, inputs, simple_mode=True))
except Exception:
Expand All @@ -59,7 +64,7 @@ def lower(sch, inputs, func_name, source_func):


@tvm._ffi.register_func("relay.backend.build")
def build(mod, target, target_host=None):
def build(mod, target, target_host=None, binds=None):
"""Backend build function.

Parameters
Expand All @@ -80,7 +85,7 @@ def build(mod, target, target_host=None):
"""
if target_host == "":
target_host = None
return tvm.driver.build(mod, target=target, target_host=target_host)
return tvm.driver.build(mod, target=target, target_host=target_host, binds=binds)


@tvm._ffi.register_func("relay._tensor_value_repr")
Expand Down
30 changes: 21 additions & 9 deletions src/relay/backend/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,11 @@ LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation im
data_ = std::move(n);
}

/*!
 * \brief Construct a compile-engine cache key.
 * \param source_func The Relay function to be lowered.
 * \param target The hardware target to lower for.
 * \param buffers Optional buffers bound to the function's arguments; an empty
 *        array means the default compact buffers will be created during lowering.
 */
CCacheKey::CCacheKey(Function source_func, Target target, Array<tir::Buffer> buffers) {
  auto n = make_object<CCacheKeyNode>();
  // Move each component into the node; the key takes ownership of all three.
  n->source_func = std::move(source_func);
  n->target = std::move(target);
  n->buffers = std::move(buffers);
  data_ = std::move(n);
}

Expand Down Expand Up @@ -612,11 +613,12 @@ class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>>
class CompileEngineImpl : public CompileEngineNode {
public:
// Lower the function.
CachedFunc Lower(const CCacheKey& key) { return LowerInternal(key)->cached_func; }
CachedFunc Lower(const CCacheKey& key, const Array<tir::Buffer>& buffers) {
return LowerInternal(key, buffers)->cached_func; }

// For now, build one module per function.
PackedFunc JIT(const CCacheKey& key) final {
CCacheValue value = LowerInternal(key);
PackedFunc JIT(const CCacheKey& key, const Array<tir::Buffer>& buffers) final {
CCacheValue value = LowerInternal(key, buffers);
if (value->packed_func != nullptr) return value->packed_func;
// build the function.
tvm::runtime::Module m;
Expand Down Expand Up @@ -711,7 +713,7 @@ class CompileEngineImpl : public CompileEngineNode {

private:
// implement lowered func
CCacheValue LowerInternal(const CCacheKey& key) {
CCacheValue LowerInternal(const CCacheKey& key, const Array<tir::Buffer>& buffers = {}) {
std::lock_guard<std::mutex> lock(mutex_);
CCacheValue value;
auto it = cache_.find(key);
Expand Down Expand Up @@ -762,9 +764,19 @@ class CompileEngineImpl : public CompileEngineNode {
for (te::Tensor arg : cache_node->outputs) {
all_args.push_back(arg);
}

// build the bind map
Map<te::Tensor, tir::Buffer> binds;
if (buffers.size() == all_args.size()) {
for (size_t i = 0; i < all_args.size(); i++) {
auto& arg = all_args[i];
binds.Set(arg, buffers[i]);
}
}

// lower the function
if (const auto* f = runtime::Registry::Get("relay.backend.lower")) {
cache_node->funcs = (*f)(cfunc->schedule, all_args, cache_node->func_name, key->source_func);
cache_node->funcs = (*f)(cfunc->schedule, all_args, cache_node->func_name, key->source_func, binds);
} else {
using tvm::transform::PassContext;
With<PassContext> fresh_pass_ctx_scope(PassContext::Create());
Expand Down Expand Up @@ -863,8 +875,8 @@ TVM_REGISTER_GLOBAL("relay.backend._make_LoweredOutput")
});

TVM_REGISTER_GLOBAL("relay.backend._make_CCacheKey")
.set_body_typed([](Function source_func, Target target) {
return CCacheKey(source_func, target);
.set_body_typed([](Function source_func, Target target, Array<tir::Buffer> buffers = {}) {
return CCacheKey(source_func, target, buffers);
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineGlobal").set_body_typed([]() {
Expand All @@ -876,7 +888,7 @@ TVM_REGISTER_GLOBAL("relay.backend._CompileEngineClear").set_body_typed([](Compi
});

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLower")
.set_body_typed([](CompileEngine self, CCacheKey key) { return self->Lower(key); });
.set_body_typed([](CompileEngine self, CCacheKey key, Array<tir::Buffer> buffers) { return self->Lower(key, buffers); });

TVM_REGISTER_GLOBAL("relay.backend._CompileEngineLowerShapeFunc")
.set_body_typed([](CompileEngine self, CCacheKey key) { return self->LowerShapeFunc(key); });
Expand Down
13 changes: 10 additions & 3 deletions src/relay/backend/compile_engine.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,8 @@ class CCacheKeyNode : public Object {
Function source_func;
/*! \brief The hardware target.*/
Target target;
/*! \brief Any buffers bound to the source function. */
Array<tir::Buffer> buffers;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("source_func", &source_func);
Expand Down Expand Up @@ -148,8 +150,9 @@ class CCacheKey : public ObjectRef {
* \brief The constructor
* \param source_func The source function.
* \param target The target device.
* \param buffers Optional bound buffers
*/
TVM_DLL CCacheKey(Function source_func, Target target);
TVM_DLL CCacheKey(Function source_func, Target target, Array<tir::Buffer> buffers = {});

const CCacheKeyNode* operator->() const { return static_cast<const CCacheKeyNode*>(get()); }
// comparator
Expand Down Expand Up @@ -201,13 +204,13 @@ class CompileEngineNode : public Object {
* \param key The key to the cached function.
* \return The result.
*/
virtual CachedFunc Lower(const CCacheKey& key) = 0;
virtual CachedFunc Lower(const CCacheKey& key, const Array<tir::Buffer>& buffers = {}) = 0;
/*!
* \brief Just in time compile to get a PackedFunc.
* \param key The key to the cached function.
* \return The result.
*/
virtual PackedFunc JIT(const CCacheKey& key) = 0;
virtual PackedFunc JIT(const CCacheKey& key, const Array<tir::Buffer>& buffers = {}) = 0;
/*!
* \brief Lower the shape function.
* \param key The key to the cached function.
Expand Down Expand Up @@ -269,6 +272,10 @@ inline size_t CCacheKeyNode::Hash() const {

inline bool CCacheKeyNode::Equal(const CCacheKeyNode* other) const {
if (Hash() != other->Hash()) return false;
if (other->buffers.size() != this->buffers.size()) return false;
for (size_t i = 0; i < other->buffers.size(); i++) {
if (!tvm::StructuralEqual()(other->buffers[i], this->buffers[i])) return false;
}
return this->target->str() == other->target->str() &&
tvm::StructuralEqual()(this->source_func, other->source_func);
}
Expand Down
82 changes: 70 additions & 12 deletions src/relay/backend/graph_plan_memory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,16 @@
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/tir/op.h>
#include <tvm/target/target.h>

#include "../../support/arena.h"

namespace tvm {
namespace relay {

using IntegerArray = Array<Integer>;
using TargetsMap = Map<Integer, Target>;
using Texture2DShape = runtime::Texture2DShape<int64_t>;
constexpr auto Is2DStorage = runtime::IsTextureStorage;

struct StorageToken {
/*! \brief Reference counter */
Expand All @@ -46,6 +49,8 @@ struct StorageToken {
int device_type{0};
/*! \brief The storage id */
int64_t storage_id{-1};
/*! \brief The storage scope */
std::string storage_scope{"global"};
};

class StorageAllocaBaseVisitor : public ExprVisitor {
Expand Down Expand Up @@ -125,14 +130,48 @@ class StorageAllocaBaseVisitor : public ExprVisitor {
virtual void CreateToken(const ExprNode* op, bool can_realloc) = 0;
};

/*!
 * \brief Collect the target specific tensor storage info for each expression's output.
 * \param expr The expression.
 * \param dev_map The device id map which can be used to infer device specific storage
 *        scope availability.
 * \param target_map The target mapping from device id to target.
 * \return The device based storage mapping; empty when no target-specific hook is
 *         registered, in which case callers fall back to the default "global" scope.
 */
Map<Expr, Array<String>> CollectStorageInfo(const Expr& expr, const Map<Expr, Integer>& dev_map,
                                            const TargetsMap& target_map) {
  // Order the device types deterministically so the registry key built below is
  // stable regardless of Map iteration order.
  auto less = [](Integer i, Integer j) {
    auto i_imm = i.as<tir::IntImmNode>();
    auto j_imm = j.as<tir::IntImmNode>();
    ICHECK(i_imm && j_imm);
    return i_imm->value < j_imm->value;
  };
  std::set<Integer, decltype(less)> device_types(less);
  for (const auto& kv : target_map) {
    device_types.insert(kv.first);
  }
  // Compose a registry key of the form
  // "relay.backend.<kind>[.<device>]...", appending one segment per involved target.
  std::string ftarget_prefix = "relay.backend";
  for (const auto& dev_id : device_types) {
    Target target = target_map[dev_id];
    ftarget_prefix += ("." + target->kind->name);
    if (Optional<String> t_device = target->GetAttr<String>("device")) {
      ftarget_prefix += ("." + t_device.value());
    }
  }
  // Dispatch to the target-specific hook when one is registered; otherwise
  // return an empty mapping.
  Map<Expr, Array<String>> storage_info = {};
  if (const auto* f = runtime::Registry::Get(ftarget_prefix + "._CollectStorageInfo")) {
    storage_info = (*f)(expr, dev_map, target_map);
  }
  return storage_info;
}

class StorageAllocaInit : protected StorageAllocaBaseVisitor {
public:
explicit StorageAllocaInit(support::Arena* arena) : arena_(arena) {}

/*! \return The internal token map */
std::unordered_map<const ExprNode*, std::vector<StorageToken*> > GetInitTokenMap(
const Function& func) {
const Function& func, const TargetsMap& targets) {
node_device_map_ = CollectDeviceInfo(func);
node_storage_map_ = CollectStorageInfo(func, node_device_map_, targets);
this->Run(func);
return std::move(token_map_);
}
Expand All @@ -143,15 +182,26 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor {
void CreateToken(const ExprNode* op, bool can_realloc) final {
ICHECK(!token_map_.count(op));
std::vector<StorageToken*> tokens;
auto expr = GetRef<Expr>(op);
int device_type =
node_device_map_.count(GetRef<Expr>(op)) ? node_device_map_[GetRef<Expr>(op)]->value : 0;
node_device_map_.count(expr) ? node_device_map_[expr]->value : 0;

Optional<Array<String>> storage_info;
if (node_storage_map_.count(GetRef<Expr>(op))) {
storage_info = node_storage_map_[GetRef<Expr>(op)];
}

if (const auto* tuple_type = op->checked_type().as<TupleTypeNode>()) {
for (Type t : tuple_type->fields) {
const auto* ttype = t.as<TensorTypeNode>();
if (storage_info.defined()) { ICHECK_EQ(tuple_type->fields.size(), storage_info.value().size()); }
for (size_t i = 0; i < tuple_type->fields.size(); i++) {
const auto* ttype = tuple_type->fields[i].as<TensorTypeNode>();
ICHECK(ttype);
StorageToken* token = arena_->make<StorageToken>();
token->ttype = ttype;
token->device_type = device_type;
if (storage_info.defined()) {
token->storage_scope = storage_info.value()[i];
}
tokens.push_back(token);
}
} else {
Expand All @@ -160,6 +210,9 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor {
StorageToken* token = arena_->make<StorageToken>();
token->ttype = ttype;
token->device_type = device_type;
if (storage_info.defined()) {
token->storage_scope = storage_info.value()[0];
}
tokens.push_back(token);
}
token_map_[op] = tokens;
Expand All @@ -180,6 +233,7 @@ class StorageAllocaInit : protected StorageAllocaBaseVisitor {
// allocator
support::Arena* arena_;
Map<Expr, Integer> node_device_map_;
Map<Expr, Array<String>> node_storage_map_;
};

class StorageAllocator : public StorageAllocaBaseVisitor {
Expand All @@ -196,28 +250,32 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
}

// Run storage allocation for a function.
Map<Expr, Array<IntegerArray> > Plan(const Function& func) {
prototype_ = StorageAllocaInit(&arena_).GetInitTokenMap(func);
Map<Expr, runtime::ADT> Plan(const Function& func, const TargetsMap& targets) {
prototype_ = StorageAllocaInit(&arena_).GetInitTokenMap(func, targets);
this->Run(func);

// The value of smap contains two integer arrays where the first array
// contains the planned storage ids and the second holds the device types.
Map<Expr, Array<IntegerArray> > smap;
Map<Expr, runtime::ADT> smap;
int num_annotated_nodes = 0;
int num_nodes = 0;

for (const auto& kv : token_map_) {
std::vector<Integer> storage_ids;
std::vector<Integer> device_types;
std::vector<String> storage_scopes;
for (StorageToken* tok : kv.second) {
if (tok->device_type) {
num_annotated_nodes++;
}
num_nodes++;
storage_ids.push_back(tok->storage_id);
device_types.push_back(tok->device_type);
storage_scopes.push_back(tok->storage_scope);
}
smap.Set(GetRef<Expr>(kv.first), Array<IntegerArray>({storage_ids, device_types}));
std::vector<ObjectRef> fields{
Array<Integer>{storage_ids}, Array<Integer>{device_types}, Array<String>{storage_scopes}};
smap.Set(GetRef<Expr>(kv.first), runtime::ADT::Tuple(fields));
}
// Either all or none of the nodes should be annotated.
if (num_annotated_nodes != 0 && num_annotated_nodes != num_nodes) {
Expand All @@ -237,7 +295,7 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
ICHECK(it != prototype_.end());
std::vector<StorageToken*> tokens;
for (StorageToken* tok : it->second) {
if (can_realloc) {
if (can_realloc && tok->storage_scope == "global") {
tokens.push_back(Request(tok));
} else {
// Allocate a new token,
Expand Down Expand Up @@ -375,8 +433,8 @@ class StorageAllocator : public StorageAllocaBaseVisitor {
std::unordered_map<const ExprNode*, std::vector<StorageToken*> > prototype_;
};

/*!
 * \brief Plan the memory (storage id, device type, storage scope) for every
 *        expression in the function.
 * \param func The function to plan memory for.
 * \param targets The mapping from device id to target, used to query
 *        target-specific storage scopes.
 * \return Per-expression tuples of (storage ids, device types, storage scopes).
 */
Map<Expr, runtime::ADT> GraphPlanMemory(const Function& func, const TargetsMap& targets) {
  return StorageAllocator().Plan(func, targets);
}

TVM_REGISTER_GLOBAL("relay.backend.GraphPlanMemory").set_body_typed(GraphPlanMemory);
Expand Down
Loading