From 8003ab515cd2e0d91d5bc2949576d5c6c2243f03 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Tue, 12 Apr 2022 10:14:31 +0000 Subject: [PATCH 1/4] [AOT] Enable A-Normal Form in the AOT executor The sequence of calls produced by the AOT executor codegen is arbitrary, especially in the presence of 'branchy' networks. This makes it difficult to analyze memory usage for each call. By running the ToANormalForm pass to insert a series of let bindings before the lowering and codegen stages, we can establish an ordering for the evaluation of the external calls, thus allowing reliable analysis of memory usage. Change-Id: Ic320b68cde83c96b228a8d1d2829a0e8ac7b768f --- src/relay/backend/aot_executor_codegen.cc | 173 +++++++++++++----- .../backend/contrib/cmsisnn/relay_to_tir.cc | 78 ++++++-- src/relay/backend/contrib/ethosu/codegen.cc | 79 ++++++-- .../example_target_hooks/relay_to_tir.cc | 60 +++++- src/relay/backend/te_compiler.cc | 2 + src/relay/transforms/target_hooks.cc | 30 ++- 6 files changed, 317 insertions(+), 105 deletions(-) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index c981f9d62b19..45c934a828c9 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -126,7 +126,14 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { for (const auto& param : func_node->params) { CreateStorage(param.get()); } - GetStorage(func_node->body); + StorageInfo si = GetStorage(func_node->body); + + // If the final expr could not be found it means it was let bound, + // manually add the var to the storage device map so it can be + // found by `UpdateMainWorkspaceSize`. 
+ if (storage_device_map_.find(func_node->body) == storage_device_map_.end()) { + storage_device_map_[func_node->body] = si; + } } void VisitExpr_(const GlobalVarNode* op) final { @@ -168,7 +175,8 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { void VisitExpr_(const IfNode* op) final { LOG(FATAL) << "if is not supported."; } void PreVisitLetBinding_(const Var& var, const Expr& value) final { - LOG(FATAL) << "let is not supported."; + VisitExpr(value); + let_bound_values_.Set(var, value); } private: @@ -215,11 +223,29 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { * \return The corresponding token. */ StorageInfo GetStorage(const Expr& expr) { + Expr true_expr = expr; + + // Don't get storage for let nodes. + while (const auto* let_node = true_expr.as()) { + VisitExpr(true_expr); + true_expr = let_node->body; + } + + // Var nodes may be let bound, if this is the case get the value. + if (true_expr->IsInstance()) { + Var var = Downcast(true_expr); + if (let_bound_values_.find(var) != let_bound_values_.end()) { + true_expr = let_bound_values_.Get(var).value(); + } + } + // See through "on_device" calls. - Expr true_expr = IgnoreOnDevice(expr); + true_expr = IgnoreOnDevice(true_expr); + VisitExpr(true_expr); auto it = storage_device_map_.find(true_expr); - ICHECK(it != storage_device_map_.end()); + ICHECK(it != storage_device_map_.end()) << "Could not find " << true_expr->GetTypeKey() << " " + << PrettyPrint(true_expr) << " in storage device map"; return it->second; } @@ -258,6 +284,8 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { std::vector return_ids_; /*! \brief the data types of the return values */ std::vector return_ttypes_; + /*! \brief Maps let var to corresponding value. */ + Map let_bound_values_; }; /*! 
\brief Code generator for AOT executor */ @@ -335,6 +363,17 @@ class AOTExecutorCodegen : public MixedModeVisitor { */ std::vector PackSid(Expr expr) { std::vector buffer_vars; + + // Var nodes may be let bound, if this is the case get the value. + if (expr->IsInstance()) { + Var var = Downcast(expr); + if (let_bound_values_.find(var) != let_bound_values_.end()) { + expr = let_bound_values_.Get(var).value(); + expr = IgnoreOnDevice(expr); + } + } + ICHECK(storage_device_map_.find(expr) != storage_device_map_.end()) + << "Storage map did not contain constant expr " << PrettyPrint(expr); StorageInfo& sinfo = storage_device_map_[expr]; // Note that an expression can have multiple sids associated with it @@ -422,9 +461,9 @@ class AOTExecutorCodegen : public MixedModeVisitor { {tir::StringImm(params_by_expr_[arg])}); // NOTE: this cast looks like a no-op, but is required for compilation downstream. // Because DataType::Handle has default bits=64, but CodeGenC does not observe this field, - // adding this cast forces the codegen to insert the cast. In this case, a cast is required - // because param_handle is actually code-generated as `const void*`, and the `const` piece - // needs to be removed. + // adding this cast forces the codegen to insert the cast. In this case, a cast is + // required because param_handle is actually code-generated as `const void*`, and the + // `const` piece needs to be removed. 
args.push_back(tvm::tir::Cast(DataType::Handle(32, 1), param_handle)); } else { auto sids = FindExpr(arg); @@ -599,6 +638,12 @@ class AOTExecutorCodegen : public MixedModeVisitor { } void VisitExpr_(const CallNode* call_node) override { + OnDeviceProps on_device_props = GetOnDeviceProps(call_node); + if (on_device_props.body.defined()) { + VisitExpr(on_device_props.body); + return; + } + DeviceCopyProps device_copy_props = GetDeviceCopyProps(call_node); CallLoweredProps call_lowered_props = GetCallLoweredProps(call_node); @@ -624,28 +669,32 @@ class AOTExecutorCodegen : public MixedModeVisitor { void VisitExpr_(const VarNode* op) override { Expr expr = GetRef(op); - StorageInfo& sinfo = storage_device_map_[expr]; + if (storage_device_map_.find(expr) != storage_device_map_.end()) { + StorageInfo& sinfo = storage_device_map_[expr]; - // If the Var node is an output node we need to copy the content of the variable to the output - // It's safe to check the SID here because Var StorageToken are never reallocated - auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]); - if (output_iter != return_sid_.end()) { - int output_index = std::distance(return_sid_.begin(), output_iter); - if (params_by_expr_.find(expr) != params_by_expr_.end()) { - auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), - {tir::StringImm(params_by_expr_[expr])}); - CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), param_handle, - /*pack_input*/ false, sinfo->storage_sizes_in_bytes[0]); - } else { - auto var_expr = FindExpr(expr); - CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), var_expr[0], - /*pack_input*/ false, sinfo->storage_sizes_in_bytes[0]); + // If the Var node is an output node we need to copy the content of the variable to the + // output It's safe to check the SID here because Var StorageToken are never reallocated + auto output_iter = std::find(return_sid_.begin(), 
return_sid_.end(), sinfo->storage_ids[0]); + if (output_iter != return_sid_.end()) { + int output_index = std::distance(return_sid_.begin(), output_iter); + if (params_by_expr_.find(expr) != params_by_expr_.end()) { + auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), + {tir::StringImm(params_by_expr_[expr])}); + CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), param_handle, + /*pack_input*/ false, sinfo->storage_sizes_in_bytes[0]); + } else { + auto var_expr = FindExpr(expr); + CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), var_expr[0], + /*pack_input*/ false, sinfo->storage_sizes_in_bytes[0]); + } } } } void VisitExpr_(const ConstantNode* op) override { Expr expr = GetRef(op); + ICHECK(storage_device_map_.find(expr) != storage_device_map_.end()) + << "Storage map did not contain constant expr " << PrettyPrint(expr); StorageInfo& sinfo = storage_device_map_[expr]; std::stringstream ss; ss << "constant_" << constant_map_.size(); @@ -674,12 +723,21 @@ class AOTExecutorCodegen : public MixedModeVisitor { } void VisitExpr_(const LetNode* op) override { - // TODO(giuseros): support Let nodes in AOT - LOG(FATAL) << "Let not yet implemented in AOT"; + auto pre_visit = [this](const LetNode* op) { + this->VisitExpr(op->var); + let_bound_values_.Set(op->var, op->value); + this->VisitExpr(op->value); + }; + auto post_visit = [this](const LetNode* op) { + this->VisitExpr(op->body); + this->visit_counter_[op] += 1; + }; + ExpandANormalForm(op, pre_visit, post_visit); } + void VisitExpr_(const TupleGetItemNode* op) override { VisitExpr(op->tuple); } void VisitExpr_(const OpNode* op) override { - if (GetRef(op) != CallLoweredOp()) { + if (GetRef(op) != CallLoweredOp() && GetRef(op) != OnDeviceOp()) { LOG(FATAL) << "All OpNodes except for call_lowered should have been expanded"; } } @@ -775,21 +833,36 @@ class AOTExecutorCodegen : public MixedModeVisitor { } /*! 
- * brief Access IO vars using the buffer vars and + * \brief Access IO vars using the buffer vars and * not the actual var. */ tir::Var GetBufferVarForIO(int index) { return main_buffer_map_[main_signature_[index]]->data; } /*! - * brief Create tir::Var for input/output while updating - * the buffer_maps. + * \brief Create tir::Var for input/output while updating the buffer_maps. + * + * \param expr The expression to evaluate. + * \param original_name The name of the tir::Var. + * \param use_unique_name Whether to generate a new unique name where a name conflicts. */ void CreateIOVar(const Expr& expr, const std::string& original_name, bool use_unique_name = true) { - if (expr->IsInstance()) { - Tuple tuple = Downcast(expr); - for (unsigned i = 0; i < tuple->fields.size(); i++) { - CreateIOVar(tuple->fields[i], original_name); + CreateIOVar(expr->checked_type(), original_name, use_unique_name); + } + + /*! + * \brief Create tir::Var for input/output while updating the buffer_maps. + * + * \param expr The expression to evaluate. + * \param original_name The name of the tir::Var. + * \param use_unique_name Whether to generate a new unique name where a name conflicts. 
+ */ + void CreateIOVar(const Type& type, const std::string& original_name, + bool use_unique_name = true) { + if (type->IsInstance()) { + TupleType tuple_type = Downcast(type); + for (unsigned i = 0; i < tuple_type->fields.size(); i++) { + CreateIOVar(tuple_type->fields[i], original_name); } } else { std::string name = original_name; @@ -798,19 +871,20 @@ class AOTExecutorCodegen : public MixedModeVisitor { } tir::Var var = tir::Var(name, DataType::Handle()); main_signature_.push_back(var); - auto tensor_type = expr->checked_type().as(); + auto tensor_type = type.as(); + ICHECK(tensor_type) << "Expected TensorType node but was " << type->GetTypeKey(); DataType elem_type = tensor_type->dtype; tir::Var buffer_var = tir::Var(name + "_buffer_var", PointerType(PrimType(elem_type), "global")); tir::Buffer buffer = tir::Buffer(buffer_var, elem_type, tensor_type->shape, {}, 0, name + "_buffer", 16, 1, tir::BufferType::kDefault); main_buffer_map_.Set(var, buffer); - io_tensor_types_.Set(var, Downcast(expr->checked_type())); + io_tensor_types_.Set(var, Downcast(type)); } } /*! - * brief Create a unique name for I/O Var + * \brief Create a unique name for I/O Var */ std::string GetUniqueIOVarName(std::string name) { if (io_var_names_.find(name) == io_var_names_.end()) { @@ -823,7 +897,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { } /*! - * brief Calculate workspace sizes for PrimFuncs in the IRModule + * \brief Calculate workspace sizes for PrimFuncs in the IRModule */ Map CalculateWorkspaceSizes( const IRModule& lowered_mod, const Map& function_metadata) { @@ -852,7 +926,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { } /*! - * brief Run USMP to plan memory for lowered IRModule + * \brief Run USMP to plan memory for lowered IRModule. */ IRModule PlanMemoryWithUSMP(const IRModule& mod) { VLOG(1) << "Planning memory with USMP for module:" << std::endl << PrettyPrint(mod); @@ -888,7 +962,7 @@ class AOTExecutorCodegen : public MixedModeVisitor { } /*! 
- * brief Run StorageRewrite to plan memory for lowered IRModule + * \brief Run StorageRewrite to plan memory for lowered IRModule. */ IRModule PlanMemoryWithStorageRewrite(const IRModule& mod) { Executor executor_config = mod->GetAttr(tvm::attr::kExecutor).value(); @@ -966,6 +1040,8 @@ class AOTExecutorCodegen : public MixedModeVisitor { std::vector return_sid_; /*! \brief This is per IO var name counter to aid the generating unique names */ std::unordered_map io_var_names_; + /*! \brief Maps let var to corresponding value. */ + Map let_bound_values_; public: AOTExecutorCodegen(runtime::Module* mod, const Array& targets) @@ -1011,6 +1087,8 @@ class AOTExecutorCodegen : public MixedModeVisitor { << ") is not one of the expected values"; } + mod = transform::ToANormalForm()(mod); + IRModule lowered_mod = tec::LowerTEPass( mod_name, [this, workspace_byte_alignment](BaseFunc func) { @@ -1056,9 +1134,9 @@ class AOTExecutorCodegen : public MixedModeVisitor { for (auto sid : kv.second->storage_ids) { // The buffer_var is created with storage_scope to be global.workspace to be serviced by // TVMBackendAllocWorkspace(TVMBAW) calls, explicitly. The reasoning being the executor - // allocates should be serviced by TVMBAWs as the data could be accessed by many devices and - // should not be lowered to the stack. For more details please refer to the discussion here: - // https://github.com/apache/tvm/issues/9022 + // allocates should be serviced by TVMBAWs as the data could be accessed by many devices + // and should not be lowered to the stack. 
For more details please refer to the discussion + // here: https://github.com/apache/tvm/issues/9022 te::Var buffer_var(MakeString("sid_", sid), PointerType(PrimType(DataType::Int(8)), "global.workspace")); sids_table_[sid] = buffer_var; @@ -1071,12 +1149,13 @@ class AOTExecutorCodegen : public MixedModeVisitor { // If output tensor names were provided use them if (auto opt = func->GetAttr>("output_tensor_names")) { Array output_tensor_names = opt.value(); - if (lowered_main_func->body->IsInstance()) { - Tuple output_tuple = Downcast(lowered_main_func->body); - for (unsigned i = 0; i < output_tuple->fields.size(); i++) { + Expr output_expr = lowered_main_func->body; + if (output_expr->checked_type()->IsInstance()) { + TupleType output_tuple_type = Downcast(output_expr->checked_type()); + for (unsigned i = 0; i < output_tuple_type->fields.size(); i++) { // AoT Executor Codegen does not create these names, // thus should be used as they are provided. - CreateIOVar(output_tuple->fields[i], output_tensor_names[i], + CreateIOVar(output_tuple_type->fields[i], output_tensor_names[i], /*use_unique_name = */ false); } } else { @@ -1094,8 +1173,8 @@ class AOTExecutorCodegen : public MixedModeVisitor { VisitExpr(lowered_main_func->body); // Create the runner function. Please note that the function is not legal yet - // because the packed calls arguments are not wrapped in TVMValues. To make this happen we need - // to run the LegalizePackedCalls pass. + // because the packed calls arguments are not wrapped in TVMValues. To make this happen we + // need to run the LegalizePackedCalls pass. 
LoweredOutput ret; ret.params = std::unordered_map>(); for (auto param : params_) { diff --git a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc index 722e7c69d9ab..210175817f9c 100644 --- a/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc +++ b/src/relay/backend/contrib/cmsisnn/relay_to_tir.cc @@ -655,19 +655,61 @@ class RelayToTIRVisitor : public MixedModeMutator { return Call(new_global_var, call->args, call->attrs, call->type_args, call->span); } - Expr Rewrite_(const CallNode* pre, const Expr& post) override { - if (const CallNode* call = post.as()) { - auto* func = call->op.as(); - if (func == nullptr) { - return post; + Expr VisitExpr_(const LetNode* op) final { + auto pre_visit = [this](const LetNode* op) { + Expr var = this->VisitExpr(op->var); + Expr value = this->VisitExpr(op->value); + // outlineable function no longer needs let binding + if (this->CanOutlineExpr(value)) { + this->memo_[var] = value; + } + }; + auto post_visit = [this](const LetNode* op) { + // Rely on the Memoizer to cache pre-visit values + Expr value = this->VisitExpr(op->value); + Expr body = this->VisitExpr(op->body); + auto expr = GetRef(op); + // drop the let binding + if (this->CanOutlineExpr(value)) { + this->memo_[expr] = this->VisitExpr(op->body); + } else { + Var var = Downcast(this->VisitExpr(op->var)); + if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { + this->memo_[expr] = expr; + } else { + this->memo_[expr] = Let(var, value, body); + } } + }; + ExpandANormalForm(op, pre_visit, post_visit); + return memo_[GetRef(op)]; + } - auto codegen_name = func->GetAttr(attr::kCompiler); - if (codegen_name.defined() && codegen_name == "cmsis-nn") { - const CallNode* inner_call = func->body.as(); + bool CanOutlineExpr(const Expr& expr) { + // TODO(@lhutton1): This behaviour is similar to the OutlineCompilerFunctions pass + // we could reuse this functionality by separating outlining and lowering in 
this + // pass. + if (!expr->IsInstance()) { + return false; + } + const auto* func = expr.as(); + auto codegen_name = func->GetAttr(attr::kCompiler); + if (!codegen_name.defined() || codegen_name != "cmsis-nn") { + return false; + } + return true; + } + + Expr Rewrite_(const CallNode* pre, const Expr& post) override { + if (const auto* call = post.as()) { + if (CanOutlineExpr(call->op)) { + const auto* func = call->op.as(); + ICHECK(func) << "Expected function node but was " << call->op->GetTypeKey(); + const auto codegen_name = func->GetAttr(attr::kCompiler); auto global_func_name = func->GetAttr(tvm::attr::kGlobalSymbol); GlobalVar new_global_var(global_func_name.value()); + const CallNode* inner_call = func->body.as(); if (!inner_call) { return CallToFuncWithoutCompilerAttr(new_global_var, GetRef(call), GetRef(func)); @@ -684,21 +726,20 @@ class RelayToTIRVisitor : public MixedModeMutator { if (comp_name == "cmsis-nn.qnn_softmax") { EmitSoftMax(new_global_var, composite_func->body); - } - if (comp_name == "cmsis-nn.qnn_mul") { + } else if (comp_name == "cmsis-nn.qnn_mul") { EmitMul(new_global_var, composite_func->body); - } - if (comp_name == "cmsis-nn.qnn_add") { + } else if (comp_name == "cmsis-nn.qnn_add") { EmitAdd(new_global_var, composite_func->body); - } - if (comp_name == "cmsis-nn.qnn_conv2d") { + } else if (comp_name == "cmsis-nn.qnn_conv2d") { EmitConv2D(new_global_var, composite_func->body); - } - if (comp_name == "cmsis-nn.qnn_fully_connected") { + } else if (comp_name == "cmsis-nn.qnn_fully_connected") { EmitFullyConnected(new_global_var, composite_func->body); - } - if (comp_name == "cmsis-nn.qnn_avg_pool2d" || comp_name == "cmsis-nn.qnn_max_pool2d") { + } else if (comp_name == "cmsis-nn.qnn_avg_pool2d" || + comp_name == "cmsis-nn.qnn_max_pool2d") { EmitPool2D(new_global_var, composite_func->body, comp_name.value()); + } else { + return CallToFuncWithoutCompilerAttr(new_global_var, GetRef(call), + GetRef(func)); } Array args; @@ -709,7 +750,6 @@ 
class RelayToTIRVisitor : public MixedModeMutator { return Call(new_global_var, args, call->attrs, call->type_args, call->span); } } - return post; } diff --git a/src/relay/backend/contrib/ethosu/codegen.cc b/src/relay/backend/contrib/ethosu/codegen.cc index dfcf54f7b76c..47c80b47c579 100644 --- a/src/relay/backend/contrib/ethosu/codegen.cc +++ b/src/relay/backend/contrib/ethosu/codegen.cc @@ -57,28 +57,81 @@ class OutlineCompilerFunctionsMutator : public MixedModeMutator { explicit OutlineCompilerFunctionsMutator(const IRModule& mod, const std::string& compiler_name) : mod_(mod), compiler_name_(compiler_name) {} + Expr VisitExpr_(const LetNode* op) final { + auto pre_visit = [this](const LetNode* op) { + Expr var = this->VisitExpr(op->var); + Expr value = this->VisitExpr(op->value); + + // Outlineable function no longer needs let binding + if (this->CanOutlineExpr(value)) { + this->memo_[var] = value; + } + }; + auto post_visit = [this](const LetNode* op) { + // Rely on the Memoizer to cache pre-visit values + Expr value = this->VisitExpr(op->value); + Expr body = this->VisitExpr(op->body); + auto expr = GetRef(op); + + // Drop the let binding + if (this->CanOutlineExpr(value)) { + this->memo_[expr] = this->VisitExpr(op->body); + } else { + Var var = Downcast(this->VisitExpr(op->var)); + if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { + this->memo_[expr] = expr; + } else { + this->memo_[expr] = Let(var, value, body); + } + } + }; + ExpandANormalForm(op, pre_visit, post_visit); + return memo_[GetRef(op)]; + } + Expr Rewrite_(const CallNode* pre, const Expr& post) override { Call call = Downcast(post); - if (call->op->IsInstance()) { + if (CanOutlineExpr(call->op)) { Function func = Downcast(call->op); - auto compiler = func->GetAttr(attr::kCompiler); - if (compiler.defined() && compiler == compiler_name_) { - auto gv_name = func->GetAttr("global_symbol").value_or(""); - ICHECK_NE(gv_name, "") - << "Function to be outlined must have 
global_symbol attribute, but didn't."; - GlobalVar gv(gv_name); - if (func->checked_type_.defined()) { - gv->checked_type_ = func->checked_type(); - } - mod_->Update(gv, func); - return Call(gv, call->args, call->attrs, call->type_args); + auto gv_name = func->GetAttr("global_symbol").value_or(""); + ICHECK_NE(gv_name, "") + << "Function to be outlined must have global_symbol attribute, but didn't."; + GlobalVar gv(gv_name); + if (func->checked_type_.defined()) { + gv->checked_type_ = func->checked_type(); } + mod_->Update(gv, func); + return Call(gv, call->args, call->attrs, call->type_args); } return post; } private: + /*! + * \brief Check if the expr is a function and has the same + * compiler name as compiler_name_. + * + * \param expr The input expr. + * \return True if is outlineable else False. + */ + bool CanOutlineExpr(const Expr& expr) { + if (!expr->IsInstance()) { + return false; + } + Function func = Downcast(expr); + auto compiler = func->GetAttr(attr::kCompiler); + if (!compiler.defined()) { + return false; + } + if (compiler != compiler_name_) { + return false; + } + return true; + } + + /*! \brief The module that the pass will run on. */ IRModule mod_; + /*! \brief The name of the compiler to enable outlining on external functions for. 
*/ std::string compiler_name_; }; @@ -188,7 +241,7 @@ class RemoveRedundantIdentities : public MixedModeMutator { const auto* call_tt = call->checked_type_.as(); const auto* identity_arg_tt = identity_arg->checked_type_.as(); - CHECK(call_tt && identity_arg_tt) + ICHECK(call_tt && identity_arg_tt) << "InferType should be run before RemoveRedundantIdentities"; // we can only remove the identity operation if the second non-compute operation diff --git a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc index 86f55caf9342..c498baa6d11d 100644 --- a/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc +++ b/src/relay/backend/contrib/example_target_hooks/relay_to_tir.cc @@ -94,23 +94,67 @@ class ConvertAddToSubtract : public MixedModeMutator { ir_module_->Add(new_global_var, replacement_func); } + Expr VisitExpr_(const LetNode* op) final { + auto pre_visit = [this](const LetNode* op) { + Expr var = this->VisitExpr(op->var); + Expr value = this->VisitExpr(op->value); + + // Outlineable function no longer needs let binding + if (this->CanLowerExpr(value)) { + this->memo_[var] = value; + } + }; + auto post_visit = [this](const LetNode* op) { + // Rely on the Memoizer to cache pre-visit values + Expr value = this->VisitExpr(op->value); + Expr body = this->VisitExpr(op->body); + auto expr = GetRef(op); + + // Drop the let binding + if (this->CanLowerExpr(value)) { + this->memo_[expr] = this->VisitExpr(op->body); + } else { + Var var = Downcast(this->VisitExpr(op->var)); + if (var.same_as(op->var) && value.same_as(op->value) && body.same_as(op->body)) { + this->memo_[expr] = expr; + } else { + this->memo_[expr] = Let(var, value, body); + } + } + }; + ExpandANormalForm(op, pre_visit, post_visit); + return memo_[GetRef(op)]; + } + + bool CanLowerExpr(const Expr& expr) { + const auto* func = expr.as(); + if (func == nullptr) { + return false; + } + auto func_name = 
func->GetAttr(::tvm::attr::kGlobalSymbol); + if (!func_name.defined()) { + return false; + } + if (func_name != "replace_add_with_subtract") { + return false; + } + return true; + } + Expr Rewrite_(const CallNode* pre, const Expr& post) override { if (const CallNode* call = post.as()) { - auto* func = call->op.as(); - if (func == nullptr) { - return post; - } + if (CanLowerExpr(call->op)) { + auto* func = call->op.as(); + auto func_name = func->GetAttr(::tvm::attr::kGlobalSymbol); - auto func_name = func->GetAttr(::tvm::attr::kGlobalSymbol); - if (func_name.defined() && func_name == "replace_add_with_subtract") { // Introduce a new global var to map the function to and copy the source type // over for InferType GlobalVar new_global_var(func_name.value()); new_global_var->checked_type_ = func->checked_type(); ReplaceAddWithSubtractPrimFunc(new_global_var, GetRef(func)); - // Since we are replacing the Relay function with a call to a TIR function, we must use the - // call_lowered op. + // Since we are replacing the Relay function with a call to a TIR function, we must use + // the call_lowered op. 
CallLoweredAttrs attrs; attrs.metadata.Set("relay_attrs", call->attrs); ICHECK(call->type_args.empty()) << "lowered functions cannot be polymorphic"; diff --git a/src/relay/backend/te_compiler.cc b/src/relay/backend/te_compiler.cc index 70d74ea92377..71b57aed81f6 100644 --- a/src/relay/backend/te_compiler.cc +++ b/src/relay/backend/te_compiler.cc @@ -678,6 +678,8 @@ class LowerTensorExprMutator : public DeviceAwareExprMutator { if (prim_func.defined()) { // Leaving let var scope primitive_functions_.erase(pre_let_node->var.get()); + // Drop the let node + return post_let_node->body; } return DeviceAwareExprMutator::PostVisitLet_(pre_let_node, post_let_node); } diff --git a/src/relay/transforms/target_hooks.cc b/src/relay/transforms/target_hooks.cc index b0ac883623d2..0022baf881ba 100644 --- a/src/relay/transforms/target_hooks.cc +++ b/src/relay/transforms/target_hooks.cc @@ -61,25 +61,19 @@ class TargetHookVisitor : public tvm::relay::MixedModeVisitor { ExpandANormalForm(op, pre_visit, post_visit); } - void VisitExpr_(const CallNode* call) override { - // Descend the call tree - for (auto arg : call->args) { - VisitExpr(arg); + void VisitExpr_(const FunctionNode* func) override { + ExprVisitor::VisitExpr_(func); + if (!func->GetAttr(attr::kCompiler).defined()) { + return; } - - if (const FunctionNode* func = call->op.as()) { - if (!func->GetAttr(attr::kCompiler).defined()) { - return; - } - String code_gen_name = func->GetAttr(attr::kCompiler).value(); - Optional target_kind = tvm::TargetKind::Get(code_gen_name); - if (!target_kind || !target_attr_map_.count(target_kind.value())) { - return; - } - Pass custom_target_pass = target_attr_map_[target_kind.value()]; - if (std::find(pass_list_.begin(), pass_list_.end(), custom_target_pass) == pass_list_.end()) { - pass_list_.push_back(custom_target_pass); - } + String code_gen_name = func->GetAttr(attr::kCompiler).value(); + Optional target_kind = tvm::TargetKind::Get(code_gen_name); + if (!target_kind || 
!target_attr_map_.count(target_kind.value())) { + return; + } + Pass custom_target_pass = target_attr_map_[target_kind.value()]; + if (std::find(pass_list_.begin(), pass_list_.end(), custom_target_pass) == pass_list_.end()) { + pass_list_.push_back(custom_target_pass); } } }; From 7d2fd9fe4b86477d2fb93a8a6a1596ca0e43b8d2 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Tue, 3 May 2022 15:44:34 +0000 Subject: [PATCH 2/4] Maintain GetStorage(var) == GetStorage(value) invariant for lets Change-Id: Id40b70f67a3e37f75b8331aa89f1819072e4d48e --- src/relay/backend/aot_executor_codegen.cc | 95 ++++++++++------------- 1 file changed, 40 insertions(+), 55 deletions(-) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 45c934a828c9..d222ff1ad91c 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -126,14 +126,7 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { for (const auto& param : func_node->params) { CreateStorage(param.get()); } - StorageInfo si = GetStorage(func_node->body); - - // If the final expr could not be found it means it was let bound, - // manually add the var to the storage device map so it can be - // found by `UpdateMainWorkspaceSize`. - if (storage_device_map_.find(func_node->body) == storage_device_map_.end()) { - storage_device_map_[func_node->body] = si; - } + GetStorage(func_node->body); } void VisitExpr_(const GlobalVarNode* op) final { @@ -176,7 +169,8 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { void PreVisitLetBinding_(const Var& var, const Expr& value) final { VisitExpr(value); - let_bound_values_.Set(var, value); + StorageInfo si = GetStorage(value); + storage_device_map_[var] = si; } private: @@ -231,14 +225,6 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { true_expr = let_node->body; } - // Var nodes may be let bound, if this is the case get the value. 
- if (true_expr->IsInstance()) { - Var var = Downcast(true_expr); - if (let_bound_values_.find(var) != let_bound_values_.end()) { - true_expr = let_bound_values_.Get(var).value(); - } - } - // See through "on_device" calls. true_expr = IgnoreOnDevice(true_expr); @@ -284,8 +270,6 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { std::vector return_ids_; /*! \brief the data types of the return values */ std::vector return_ttypes_; - /*! \brief Maps let var to corresponding value. */ - Map let_bound_values_; }; /*! \brief Code generator for AOT executor */ @@ -364,14 +348,6 @@ class AOTExecutorCodegen : public MixedModeVisitor { std::vector PackSid(Expr expr) { std::vector buffer_vars; - // Var nodes may be let bound, if this is the case get the value. - if (expr->IsInstance()) { - Var var = Downcast(expr); - if (let_bound_values_.find(var) != let_bound_values_.end()) { - expr = let_bound_values_.Get(var).value(); - expr = IgnoreOnDevice(expr); - } - } ICHECK(storage_device_map_.find(expr) != storage_device_map_.end()) << "Storage map did not contain constant expr " << PrettyPrint(expr); StorageInfo& sinfo = storage_device_map_[expr]; @@ -461,9 +437,9 @@ class AOTExecutorCodegen : public MixedModeVisitor { {tir::StringImm(params_by_expr_[arg])}); // NOTE: this cast looks like a no-op, but is required for compilation downstream. // Because DataType::Handle has default bits=64, but CodeGenC does not observe this field, - // adding this cast forces the codegen to insert the cast. In this case, a cast is - // required because param_handle is actually code-generated as `const void*`, and the - // `const` piece needs to be removed. + // adding this cast forces the codegen to insert the cast. In this case, a cast is required + // because param_handle is actually code-generated as `const void*`, and the `const` piece + // needs to be removed. 
args.push_back(tvm::tir::Cast(DataType::Handle(32, 1), param_handle)); } else { auto sids = FindExpr(arg); @@ -669,24 +645,27 @@ class AOTExecutorCodegen : public MixedModeVisitor { void VisitExpr_(const VarNode* op) override { Expr expr = GetRef(op); - if (storage_device_map_.find(expr) != storage_device_map_.end()) { - StorageInfo& sinfo = storage_device_map_[expr]; + StorageInfo& sinfo = storage_device_map_[expr]; - // If the Var node is an output node we need to copy the content of the variable to the - // output It's safe to check the SID here because Var StorageToken are never reallocated - auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]); - if (output_iter != return_sid_.end()) { - int output_index = std::distance(return_sid_.begin(), output_iter); - if (params_by_expr_.find(expr) != params_by_expr_.end()) { - auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(), - {tir::StringImm(params_by_expr_[expr])}); - CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), param_handle, - /*pack_input*/ false, sinfo->storage_sizes_in_bytes[0]); - } else { - auto var_expr = FindExpr(expr); - CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), var_expr[0], - /*pack_input*/ false, sinfo->storage_sizes_in_bytes[0]); - } + // Let bound vars refer to a value, so these should not be considered "output" vars. 
+ if (let_bound_vars_.find(GetRef<Expr>(op)) != let_bound_vars_.end()) {
+ return;
+ }
+
+ // If the Var node is an output node we need to copy the content of the variable to the output
+ // It's safe to check the SID here because Var StorageToken are never reallocated
+ auto output_iter = std::find(return_sid_.begin(), return_sid_.end(), sinfo->storage_ids[0]);
+ if (output_iter != return_sid_.end()) {
+ int output_index = std::distance(return_sid_.begin(), output_iter);
+ if (params_by_expr_.find(expr) != params_by_expr_.end()) {
+ auto param_handle = tvm::tir::Call(DataType::Handle(), tvm::tir::builtin::lookup_param(),
+ {tir::StringImm(params_by_expr_[expr])});
+ CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), param_handle,
+ /*pack_input*/ false, sinfo->storage_sizes_in_bytes[0]);
+ } else {
+ auto var_expr = FindExpr(expr);
+ CopyToOutput(GetBufferVarForIO(input_vars_.size() + output_index), var_expr[0],
+ /*pack_input*/ false, sinfo->storage_sizes_in_bytes[0]);
 }
 }
 }
@@ -724,8 +703,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
 void VisitExpr_(const LetNode* op) override {
 auto pre_visit = [this](const LetNode* op) {
+ let_bound_vars_.insert(op->var);
 this->VisitExpr(op->var);
- let_bound_values_.Set(op->var, op->value);
 this->VisitExpr(op->value);
 };
 auto post_visit = [this](const LetNode* op) {
@@ -789,6 +768,12 @@ class AOTExecutorCodegen : public MixedModeVisitor {
 continue;
 }
+
+ // Make sure it hasn't already been allocated, this can happen
+ // with let-bound var/value pairs.
+ if (allocated.find(sid) != allocated.end()) {
+ continue;
+ }
+
 allocated[sid] = constant_map_.count(sids_table_[sid]);
 // TODO(giuseros): we should allocate this once outside the PrimFunc
@@ -1040,8 +1025,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
 std::vector<int> return_sid_;
 /*! \brief This is per IO var name counter to aid the generating unique names */
 std::unordered_map<std::string, int> io_var_names_;
 /*! \brief Maps let var to corresponding value.
*/
- Map<Var, Expr> let_bound_values_;
+ /*! \brief A set of variables that are let bound. */
+ std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> let_bound_vars_;
 public:
 AOTExecutorCodegen(runtime::Module* mod, const Array<Target>& targets)
@@ -1134,9 +1119,9 @@ class AOTExecutorCodegen : public MixedModeVisitor {
 for (auto sid : kv.second->storage_ids) {
 // The buffer_var is created with storage_scope to be global.workspace to be serviced by
 // TVMBackendAllocWorkspace(TVMBAW) calls, explicitly. The reasoning being the executor
- // allocates should be serviced by TVMBAWs as the data could be accessed by many devices
- // and should not be lowered to the stack. For more details please refer to the discussion
- // here: https://github.com/apache/tvm/issues/9022
+ // allocates should be serviced by TVMBAWs as the data could be accessed by many devices and
+ // should not be lowered to the stack. For more details please refer to the discussion here:
+ // https://github.com/apache/tvm/issues/9022
 te::Var buffer_var(MakeString("sid_", sid),
 PointerType(PrimType(DataType::Int(8)), "global.workspace"));
 sids_table_[sid] = buffer_var;
@@ -1173,8 +1158,8 @@ class AOTExecutorCodegen : public MixedModeVisitor {
 VisitExpr(lowered_main_func->body);
 // Create the runner function. Please note that the function is not legal yet
- // because the packed calls arguments are not wrapped in TVMValues. To make this happen we
- // need to run the LegalizePackedCalls pass.
+ // because the packed calls arguments are not wrapped in TVMValues. To make this happen we need
+ // to run the LegalizePackedCalls pass.
LoweredOutput ret; ret.params = std::unordered_map>(); for (auto param : params_) { From 02279c26ab534ec0fcb4b8b754e1347b7883e78a Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Tue, 3 May 2022 16:44:31 +0000 Subject: [PATCH 3/4] Add check to ensure ANF runs in AOT Change-Id: I8de2bd19c7c17057e2bc89f6a68595780c2e9433 --- tests/python/relay/aot/test_crt_aot.py | 47 ++++++++++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/tests/python/relay/aot/test_crt_aot.py b/tests/python/relay/aot/test_crt_aot.py index 3c44d2bf1bc8..2991cc01fc92 100644 --- a/tests/python/relay/aot/test_crt_aot.py +++ b/tests/python/relay/aot/test_crt_aot.py @@ -36,6 +36,7 @@ from tvm.relay.backend import Executor, Runtime from tvm.micro import model_library_format as mlf from tvm.micro import export_model_library_format +from tvm.ir.instrument import pass_instrument from aot_test_utils import ( AOTTestModel, AOT_DEFAULT_RUNNER, @@ -1027,5 +1028,51 @@ def test_aot_codegen_checks_returns(): ) +def test_aot_uses_anf(): + """Checks that A-Normal Form is being used in the AOT lowering pipeline.""" + x = relay.var("x", shape=(1, 10, 10, 10)) + y = relay.var("y", shape=(1, 10, 10, 10)) + z = relay.add(x, y) + func = relay.Function([x, y], z) + + @pass_instrument + class CheckANFRuns: + def __init__(self): + self.did_run_anf = False + + def run_before_pass(self, _, info): + if info.name == "ToANormalForm": + self.did_run_anf = True + if info.name == "LowerTE": + assert self.did_run_anf, "ToANormalForm pass should run before LowerTE." 
+ + check_run_anf = CheckANFRuns() + + model = AOTTestModel(module=IRModule.from_expr(func), inputs=None, outputs=None) + runtime = Runtime("crt") + executor = Executor( + "aot", + { + "workspace-byte-alignment": 8, + "interface-api": "c", + "unpacked-api": True, + }, + ) + config = {"tir.disable_vectorize": True} + + with tvm.transform.PassContext(opt_level=3, config=config, instruments=[check_run_anf]): + tvm.relay.build( + model.module, + tvm.target.Target("c"), + executor=executor, + runtime=runtime, + workspace_memory_pools=None, + params=model.params, + mod_name=model.name, + ) + + assert check_run_anf.did_run_anf, "Expected ToANormalForm pass to have run." + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) From 7a57baac5905205cc042ca4f152ecdb4dfea6f9d Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Thu, 5 May 2022 16:12:42 +0000 Subject: [PATCH 4/4] Avoid let block traversal and don't visit var in let visitation Change-Id: I74c080e2a09e84a75400db5c3395d508697d5d0f --- src/relay/backend/aot_executor_codegen.cc | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index d222ff1ad91c..60f108aacf66 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -107,7 +107,7 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { VisitExpr(func); CreateStorage(call_node); for (const Expr& arg : args) { - GetStorage(arg); + VisitExpr(arg); } AssignReturnSid(GetRef(call_node)); } @@ -126,7 +126,7 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { for (const auto& param : func_node->params) { CreateStorage(param.get()); } - GetStorage(func_node->body); + VisitExpr(func_node->body); } void VisitExpr_(const GlobalVarNode* op) final { @@ -217,17 +217,8 @@ class AOTOnDemandAllocator : public transform::DeviceAwareExprVisitor { * \return The 
corresponding token. */
 StorageInfo GetStorage(const Expr& expr) {
- Expr true_expr = expr;
-
- // Don't get storage for let nodes.
- while (const auto* let_node = true_expr.as<LetNode>()) {
- VisitExpr(true_expr);
- true_expr = let_node->body;
- }
-
 // See through "on_device" calls.
- true_expr = IgnoreOnDevice(true_expr);
-
+ Expr true_expr = IgnoreOnDevice(expr);
 VisitExpr(true_expr);
 auto it = storage_device_map_.find(true_expr);
 ICHECK(it != storage_device_map_.end()) << "Could not find " << true_expr->GetTypeKey() << " "
@@ -704,7 +695,6 @@ class AOTExecutorCodegen : public MixedModeVisitor {
 void VisitExpr_(const LetNode* op) override {
 auto pre_visit = [this](const LetNode* op) {
 let_bound_vars_.insert(op->var);
- this->VisitExpr(op->var);
 this->VisitExpr(op->value);
 };
 auto post_visit = [this](const LetNode* op) {