237 changes: 64 additions & 173 deletions src/relay/backend/aot_executor_codegen.cc
@@ -38,15 +38,14 @@
#include <string>
#include <vector>

#include "compile_engine.h"
#include "te_compiler.h"
#include "utils.h"

namespace tvm {
namespace relay {
namespace backend {

using IntegerArray = Array<Integer>;
using TargetsMap = std::unordered_map<int, Target>;
using StorageMap =
std::unordered_map<Expr, StorageInfo, runtime::ObjectPtrHash, runtime::ObjectPtrEqual>;

@@ -287,7 +286,6 @@ class AOTExecutorCodegen : public ExprVisitor {
void CreateFuncCall(Call call, std::string func_name) {
tvm::Array<PrimExpr> args{tvm::tir::StringImm(func_name)};
std::vector<tir::Stmt> create_func_call_stmts;

// Pack the inputs
for (Expr arg : call->args) {
if (params_by_expr_.find(arg) != params_by_expr_.end()) {
@@ -365,155 +363,21 @@ class AOTExecutorCodegen : public ExprVisitor {
return ss.str();
}

/*!
* \brief Update the "main" control function's metadata
*
* \param func The main function that contains calls to operator tir primitive functions
*/
void UpdateMainWorkspaceSize(const tir::PrimFunc& primfunc, const relay::Function& func) {
auto workspace_byte_alignment = target_host_->GetAttr<Integer>("workspace-byte-alignment")
.value_or(tvm::runtime::kDefaultWorkspaceAlignment);
Integer workspace_size = CalculateWorkspaceBytes(primfunc, workspace_byte_alignment);
// Populate FunctionInfo
auto fi_node = make_object<FunctionInfoNode>();
// Initialize all target workspaces to zero
for (const auto& kv : targets_) {
auto tgt = kv.second;
fi_node->workspace_sizes.Set(tgt, 0);
}
fi_node->workspace_sizes.Set(target_host_, workspace_size);
fi_node->relay_primfuncs.Set(target_host_, func);

int64_t io_size = 0;
for (const auto& input : input_vars_) {
io_size += CalculateRelayExprSizeBytes(input->checked_type());
}
io_size += CalculateRelayExprSizeBytes(func->body->checked_type());
fi_node->io_sizes.Set(target_host_, io_size);

int64_t const_size = 0;
for (const auto& kv : params_by_expr_) {
const_size += CalculateRelayExprSizeBytes(kv.first->checked_type());
}
fi_node->constant_sizes.Set(target_host_, const_size);
function_metadata_.Set(String(runtime::symbol::tvm_module_main), FunctionInfo(fi_node));
}

/*!
* \brief Update the function metadata for a given cached function and its relay
* primitive function.
*
* \param cfunc The cached function as provided the by the compile engine
* \param relay_func The source relay primitive function
* \param relay_target The target associated with relay primitive function
*/
void UpdateFunctionMetadata(const CachedFunc& cfunc, const Function& relay_func,
const Target& relay_target) {
auto fi_node = make_object<FunctionInfoNode>();
for (const auto& kv : cfunc->funcs->functions) {
auto primfunc = Downcast<tir::PrimFunc>(kv.second);
auto workspace_byte_alignment =
target_host_->GetAttr<Integer>("workspace-byte-alignment").value_or(16);
Integer workspace_size = CalculateWorkspaceBytes(primfunc, workspace_byte_alignment);
Target primfunc_target = relay_target;
if (primfunc->attrs->dict.count("target")) {
primfunc_target = Downcast<Target>(primfunc->attrs->dict["target"]);
}
fi_node->workspace_sizes.Set(primfunc_target, workspace_size);
// Calculating size for I/O
for (auto const& param : primfunc->params) {
auto p_shape = primfunc->buffer_map[param]->shape;
int num_of_elements = 1;
for (const auto& dim_index_expr : p_shape) {
if (dim_index_expr->IsInstance<IntImmNode>()) {
num_of_elements *= dim_index_expr.as<IntImmNode>()->value;
} else {
// If shape is dynamic, we cannot calculate workspace in compile time.
num_of_elements = 0;
}
}
int element_size = primfunc->buffer_map[param]->dtype.bytes();
fi_node->io_sizes.Set(primfunc_target, element_size * num_of_elements);
}
fi_node->constant_sizes.Set(primfunc_target, 0);
fi_node->tir_primfuncs.Set(primfunc_target, primfunc);
fi_node->relay_primfuncs.Set(primfunc_target, relay_func);
}
function_metadata_.Set(cfunc->prim_fn_var->name_hint, FunctionInfo(fi_node));
}

void VisitExpr_(const CallNode* op) override {
// Descend the call tree
for (auto arg : op->args) {
VisitExpr(arg);
}

Expr expr = GetRef<Expr>(op);
Function func;
if (op->op.as<OpNode>()) {
LOG(FATAL) << "Operators should be transformed away; try applying"
<< "the fuse_ops transformation to the expression.";
} else if (op->op.as<GlobalVarNode>()) {
LOG(FATAL) << "Not implemented";
} else if (op->op.as<FunctionNode>()) {
func = GetRef<Function>(op->op.as<FunctionNode>());
GlobalVar node = GetRef<GlobalVar>(op->op.as<GlobalVarNode>());
CreateFuncCall(GetRef<Call>(op), node->name_hint);
} else {
LOG(FATAL) << "TVM runtime does not support calls to " << op->op->GetTypeKey();
}
if (!func->HasNonzeroAttr(attr::kPrimitive)) {
Contributor: I think we should maintain this check -- maybe inside CreateFuncCall?

Member (author): Post-lowering, every function is essentially a GlobalVar, so this path was never called. If there's a test case that shows this, I can re-introduce it.

Contributor (@manupak, Aug 12, 2021): I think this is an assumption turned assertion that all calls are made to primitive functions. It guarantees that the Relay lowering is done before the respective executor codegen is invoked. But I see your point -- it was never checked as it should have been. Maybe it's worth checking that the function attached to the GlobalVar has this property?

Member (author): Everything is already lowered at this point, as it's been through LowerTE before this runs, so we don't have to make the assumption - it's guaranteed 😸 I'd also suggest that we don't add defensive code which we can't craft a way to invoke?

Contributor: Ah, I see. So it's kind of passed onto LowerTE. Yeah, if the check itself is never invoked, there's no point in having it there. LGTM.

LOG(FATAL) << "TVM only support calls to primitive functions "
<< "(i.e functions composed of fusable operator invocations)";
}
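As an editorial illustration of the alternative floated in the thread above (not code from this PR; the helper name and the module-lookup approach are assumptions), a defensive check that the function bound to a GlobalVar was actually lowered to a primitive might look roughly like this:

```cpp
// Hypothetical sketch, not part of the PR: assert that the callee bound to a
// GlobalVar carries the kPrimitive attribute before AOT codegen consumes it.
// Assumes the surrounding tvm::relay namespaces and headers from this file.
void CheckCalleeIsPrimitive(const IRModule& mod, const GlobalVar& gvar) {
  BaseFunc base_func = mod->Lookup(gvar);
  if (const auto* fn = base_func.as<FunctionNode>()) {
    ICHECK(fn->HasNonzeroAttr(attr::kPrimitive))
        << "Call to non-primitive function " << gvar->name_hint
        << "; LowerTE is expected to lower all calls before AOT codegen runs.";
  }
}
```

As the author notes, LowerTE already guarantees this invariant, so such a check would be unreachable in practice - which is why it was dropped rather than moved.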

Target target;

// Handle external function
if (func->GetAttr<String>(attr::kCompiler).defined()) {
target = Target("ext_dev");
CCacheKey key = CCacheKey(func, target);
CachedFunc ext_func = compile_engine_->Lower(key, mod_name_);
ICHECK(ext_func.defined()) << "External function is not defined.";
UpdateConstants(func, &params_);

// Generate the TIR function call
CreateFuncCall(GetRef<Call>(op), ext_func->prim_fn_var->name_hint);
return;
}

ICHECK_GE(storage_device_map_.count(expr), 0);
StorageInfo& sinfo = storage_device_map_[expr];
auto call_dev_type = sinfo->device_types[0];
// Normal Relay Function
if (targets_.size() == 1) {
// homogeneous execution.
const auto& it = targets_.begin();
target = (*it).second;
} else {
// heterogeneous execution.
std::string call_dev_name;
if (call_dev_type == 0) {
call_dev_name = "llvm";
} else {
call_dev_name = runtime::DeviceName(call_dev_type);
}
if (targets_.count(call_dev_type) == 0) {
LOG(FATAL) << "No target is provided for device " << call_dev_name;
}
target = targets_[call_dev_type];
}

CCacheKey key = CCacheKey(func, target);
CachedFunc lowered_func = compile_engine_->Lower(key, mod_name_);

if (!lowered_funcs_.count(target->str())) {
lowered_funcs_[target->str()] = IRModule(Map<GlobalVar, BaseFunc>({}));
}
lowered_funcs_[target->str()]->Update(lowered_func->funcs);
// Update function metadata via looking at all primfuncs
UpdateFunctionMetadata(lowered_func, func, target);

// Generate the TIR function call
CreateFuncCall(GetRef<Call>(op), lowered_func->prim_fn_var->name_hint);
}

void VisitExpr_(const VarNode* op) override {
@@ -598,7 +462,7 @@ class AOTExecutorCodegen : public ExprVisitor {
// Create the main PrimFunc to execute the graph. Please note that
// the packed function calls don't pack their arguments. The AOT
// runner function needs to be legalized by the LegalizePackedCalls pass.
tir::PrimFunc CreateMainFunc(unsigned int relay_params) {
tir::PrimFunc CreateMainFunc(String mod_name, unsigned int relay_params) {
tir::Stmt body = tir::SeqStmt(stmts_);

// Allocate the sids
@@ -637,7 +501,7 @@ class AOTExecutorCodegen : public ExprVisitor {
// Define the PrimFunc attributes
Map<String, ObjectRef> dict_attrs;
String run_func_name =
runtime::get_name_mangled(mod_name_, runtime::symbol::tvm_run_func_suffix);
runtime::get_name_mangled(mod_name, runtime::symbol::tvm_run_func_suffix);
dict_attrs.Set("global_symbol", run_func_name);
dict_attrs.Set("runner_function", Bool(true));

@@ -654,7 +518,7 @@ class AOTExecutorCodegen : public ExprVisitor {
/*! \brief input and output variables belonging to the main function signature */
Array<tir::Var> main_signature_;
/*! \brief target device */
TargetsMap targets_;
tec::TargetMap targets_;
/*! \brief target host */
Target target_host_;
/*!
@@ -684,35 +548,70 @@ class AOTExecutorCodegen : public ExprVisitor {
/*! \brief mapping sid -> tir::Var */
std::unordered_map<int, te::Var> sids_table_;
/*! \brief lowered funcs */
std::unordered_map<std::string, IRModule> lowered_funcs_;
/*! \brief lowered funcs */
Map<String, FunctionInfo> function_metadata_;
/*! \brief compile engine */
CompileEngine compile_engine_;
/*! \brief the set of statements that make the program */
std::vector<tir::Stmt> stmts_;
/*! \brief the list of return sids (note that the function might return more then one output */
std::vector<int> return_sid_;
/*! \brief the module name we use to mangle the function names */
String mod_name_;

public:
AOTExecutorCodegen(runtime::Module* mod, const TargetsMap& targets, Target target_host)
AOTExecutorCodegen(runtime::Module* mod, const tec::TargetMap& targets, Target target_host)
: mod_(mod),
targets_(targets),
target_host_(target_host),
use_unpacked_api_(target_host->GetAttr<Bool>("unpacked-api").value_or(Bool(false))),
compile_engine_(CompileEngine::Global()) {}
use_unpacked_api_(target_host->GetAttr<Bool>("unpacked-api").value_or(Bool(false))) {}

LoweredOutput Codegen(relay::Function func, String mod_name) {
auto aot_allocator = AOTOnDemandAllocator();
aot_allocator.Run(func);

// Retrieve the storage map
storage_device_map_ = aot_allocator.GetStorageMap();
mod_name_ = mod_name;
// Pre-lowering storage map and memory plan
StorageMap initial_storage_map = aot_allocator.GetStorageMap();
StaticMemoryPlan memory_plan(initial_storage_map);

// Build a map from each operation to device.
tec::DeviceMap device_context_map;
Member: @mbs-octoml I don't think @Mousius needs to do this in this patch, but this is where we should split device planning and storage planning. I think we can remove the need to pre-storage-plan at all if we can obtain the device information pre-lowering, then storage-plan after the lowering.

Contributor: Got it.

for (const auto& it : memory_plan->expr_to_storage_info) {
auto expr = it.first;
auto storage_info = it.second;
auto device_types = storage_info->device_types;
// CHECK_EQ(device_types.size(), 1);
tvm::Device dev;
dev.device_id = 0;
dev.device_type = device_types[0];
device_context_map.insert({expr, dev});
}

// This first phase moves from implicit use of compile engine,
// to instead explicitly lowering the incoming IRModule, and then
// performing the preexisting AOT executor code generation phase.
IRModule mod = IRModule::FromExpr(func);
auto lowered_module = tec::LowerTE(
mod, targets_, device_context_map, memory_plan, mod_name, [this](Function func) {
// We need to maintain the constant map for external
// functions so we pass this processing function which
// allows us to process each function as we lower it.
if (func->GetAttr<String>(attr::kCompiler).defined()) {
UpdateConstants(func, &params_);
}

// TODO(@areusch, @jroesch): We should refactor this to
// execute as a further pass, instead writing data to the
// lowering process directly.
tec::UpdateFunctionMetadata(func, this->function_metadata_);
});

for (auto input : func->params) {
function_metadata_.Set(runtime::symbol::tvm_module_main, lowered_module.main_func_info);
Member: @mbs-octoml I think my point is we should be able to do something like lowered_module.main_module["main"].GetAttr("function_info"). Maybe your point is to clean this up in a future patch after landing this?
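A rough sketch of the attribute-based access suggested above; the "function_info" attribute key and its FunctionInfo payload are assumptions here, not an API this PR introduces:

```cpp
// Editorial sketch: read the main function's metadata off a (hypothetical)
// "function_info" attribute instead of the LoweredModule side field.
BaseFunc lowered_main = lowered_module.main_module->Lookup("main");
if (auto info = lowered_main->GetAttr<FunctionInfo>("function_info")) {
  function_metadata_.Set(runtime::symbol::tvm_module_main, info.value());
}
```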

auto lowered_main = lowered_module.main_module->Lookup("main");
auto lowered_main_func = GetRef<Function>(lowered_main.as<FunctionNode>());

// Post-lowering storage map for writing main func - this should be the same map as previously
// created, just referencing the new expressions created from lowering
auto new_allocator = AOTOnDemandAllocator();
new_allocator.Run(lowered_main_func);
storage_device_map_ = new_allocator.GetStorageMap();
Contributor: Feel free to leave a TODO(mbs) to remove this reconstruction, since I'm trying to replace these Expr->Storage and Expr->Device side maps with attrs.

Member (author): I think it's OK without an explicit TODO here, as it makes sense to replan the allocations? I appreciate it'll get improved later 😸

Contributor: Sure, but then can you comment that the storage map should be morally the same as the original, just with the keys updated to follow along with the rewritten primitive calls? Or at least that's what I think should be happening - is that right?

Member (author): Yip, that's correct. I've updated the comment to clarify that - what do you think now?

Contributor: lgtm


for (auto input : lowered_main_func->params) {
input_vars_.push_back(input);
main_signature_.push_back(tir::Var("input", DataType::Handle()));
}
@@ -732,13 +631,12 @@ class AOTExecutorCodegen : public ExprVisitor {
main_signature_.push_back(tir::Var("output", DataType::Handle()));
}

VisitExpr(func->body);
VisitExpr(lowered_main_func->body);

// Create the runner function. Please note that the function is not legal yet
// because the packed calls arguments are not wrapped in TVMValues. To make this happen we need
// to run the LegalizePackedCalls pass.
auto prim_func = CreateMainFunc(func->params.size());
UpdateMainWorkspaceSize(prim_func, func);
auto prim_func = CreateMainFunc(mod_name, lowered_main_func->params.size());
LoweredOutput ret;

ret.params = std::unordered_map<std::string, std::pair<int, const tvm::runtime::NDArray>>();
@@ -748,17 +646,7 @@ class AOTExecutorCodegen : public ExprVisitor {
std::make_pair(static_cast<int>(param_storage_ids_[param.first]), param.second)));
}

for (auto& kv : lowered_funcs_) {
if (ret.lowered_funcs.count(kv.first) == 0) {
ret.lowered_funcs.Set(kv.first, IRModule(Map<GlobalVar, BaseFunc>({})));
}
auto& mod = ret.lowered_funcs[kv.first];
mod->Update(kv.second);
ret.lowered_funcs.Set(kv.first, mod);
}
ret.external_mods = compile_engine_->LowerExternalFunctions();

// Build the TIR IRModule
// Build the TIR IRModule for the AOT function
Map<GlobalVar, BaseFunc> symbol_map;
symbol_map.Set(GlobalVar(::tvm::runtime::symbol::tvm_run_func_suffix), prim_func);
IRModule mod_run(symbol_map);
@@ -774,14 +662,17 @@ class AOTExecutorCodegen : public ExprVisitor {
mod_run = pack_calls(mod_run);
}

// Update the lowered functions
ret.function_metadata = std::move(function_metadata_);
Member: Is it possible for us to remove the specialized main function's FunctionInfo field in LoweredModule? They should all be uniformly set inside the lowering, and I think the specialized main is a holdover from Lily's code that could be cleaned up.

Member (author): That field is still useful, and we can unify it between Graph and AOT to refactor altogether later; I've just pushed up a change to use the main_func_info attribute on the LoweredModule rather than the AOT-bespoke way of processing this information, so it can be refactored in unison later.


ret.lowered_funcs = lowered_module.per_target_module;
ret.external_mods = lowered_module.external_mods;

auto target_host_str = target_host_->str();
if (ret.lowered_funcs.find(target_host_str) != ret.lowered_funcs.end()) {
ret.lowered_funcs[target_host_str]->Update(mod_run);
} else {
ret.lowered_funcs.Set(target_host_str, mod_run);
}
ret.function_metadata = std::move(function_metadata_);

std::vector<String> input_var_names(input_vars_.size());
std::transform(input_vars_.begin(), input_vars_.end(), input_var_names.begin(),
@@ -845,15 +736,15 @@ class AOTExecutorCodegenModule : public runtime::ModuleNode {

private:
void init(void* mod, Map<Integer, tvm::Target> tmp) {
TargetsMap targets;
tec::TargetMap targets;
Target target_host;
for (const auto& it : tmp) {
auto dev_type = it.first.as<tir::IntImmNode>();
if (!target_host.defined() && it.second->kind->device_type == kDLCPU) {
target_host = it.second;
}
ICHECK(dev_type);
targets[dev_type->value] = it.second;
targets[static_cast<DLDeviceType>(dev_type->value)] = it.second;
}
codegen_ = std::make_shared<AOTExecutorCodegen>(reinterpret_cast<runtime::Module*>(mod),
targets, target_host);