From 77604f42a0381f054242a577a1ded3c7966f4aef Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Wed, 4 May 2022 09:22:24 +0000 Subject: [PATCH 1/5] [AOT] Calculate used memory at the callsite of primitive functions Introduces a new pass in the AOT executor called "AnnotateUsedMemory" which applies liveness analysis to the callsite of each primitive function in order to calculate the total size of the live tensors at this point of execution. The result is provided as a function annotation called "used_memory", which can be consumed by later stages of the compiler (e.g. external codegens) to provide more information about the current memory consumption. This can be useful for some optimizations. Change-Id: I8d6b7447498f19260358bbefe34029ddd86b9c89 --- src/relay/backend/aot/annotate_used_memory.cc | 156 ++++++++ src/relay/backend/aot_executor_codegen.cc | 3 + src/relay/backend/manifest_lifetimes.cc | 367 ++++++++++++++++++ ...fest_lifetimes.cc => manifest_lifetimes.h} | 355 ++--------------- .../relay/aot/test_used_memory_annotator.py | 194 +++++++++ 5 files changed, 757 insertions(+), 318 deletions(-) create mode 100644 src/relay/backend/aot/annotate_used_memory.cc create mode 100644 src/relay/backend/manifest_lifetimes.cc rename src/relay/backend/{vm/manifest_lifetimes.cc => manifest_lifetimes.h} (51%) create mode 100644 tests/python/relay/aot/test_used_memory_annotator.py diff --git a/src/relay/backend/aot/annotate_used_memory.cc b/src/relay/backend/aot/annotate_used_memory.cc new file mode 100644 index 000000000000..25d1624635e0 --- /dev/null +++ b/src/relay/backend/aot/annotate_used_memory.cc @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/aot/memory_pressure.cc + * \brief Analyses the memory pressure at external function callsites. + */ + +#include +#include + +#include "../../transforms/device_aware_visitors.h" +#include "../manifest_lifetimes.h" + +namespace tvm { +namespace relay { +namespace backend { +namespace aot { + +/*! + * \brief Annotates the memory usage of each primitive function by analysing the liveness + * of the input/output tensors at the function callsite and calculating the total amount of + * memory these tensors require. + */ +class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator { + public: + AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg, + const transform::LivenessAnalysis& lva) + : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {} + + /*! + * \brief Get the memory required for a primitive Relay function by calculating the total + * bytes of the live tensors at the callsite of the function. + * + * \param live_tensors The tensors that are live when the function is called. + * \return int The total number of bytes a function requires. + */ + int GetMemoryUsage(const transform::VarSet& live_tensors) { + Array types_stack = {}; + int memory_usage = 0; + + for (const Var& var : live_tensors) { + Type var_type = var->checked_type(); + ICHECK(var_type.defined()) << "InferTypes pass should be run before AnnotateUsedMemory pass."; + types_stack.push_back(var_type); + } + + while (!types_stack.empty()) { + Type current_type = types_stack.back(); + types_stack.pop_back(); + + if (const auto* tt_node = current_type.as()) { + for (const Type& type : tt_node->fields) { + types_stack.push_back(type); + } + continue; + } else if (const auto* ft_node = current_type.as()) { + types_stack.push_back(ft_node->ret_type); + continue; + } + + const auto* tt_node = current_type.as(); + ICHECK(tt_node) << "Expected TensorTypeNode but was " << current_type->GetTypeKey(); + int total_tensor_bytes = GetTensorBytes(tt_node); + memory_usage += total_tensor_bytes; + } + return memory_usage; + } + + /*! + * \brief Get the number of bytes a tensor requires. + * + * \param tensor_type_node The checked type of the tensor. + * \return int The number of bytes required. + */ + int GetTensorBytes(const TensorTypeNode* tensor_type_node) { + PrimExpr size = tensor_type_node->Size(); + const auto* size_int_imm = size.as(); + ICHECK(size_int_imm) << "Expected tensor size to be an IntImmNode but was " + << size->GetTypeKey(); + + int total_size = size_int_imm->value; + int dtype_bytes = tensor_type_node->dtype.bytes(); + return total_size * dtype_bytes; + } + + Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) override { + if (const auto* func_node = pre_let_node->value.as()) { + const auto let_bound_values = control_flow_graph_.let_map; + const transform::ControlFlowGraph::NodePtr cfg_node = + let_bound_values.at(GetRef(pre_let_node)); + const transform::VarSet& liveness_out = liveness_.live_out.at(cfg_node); + int memory_pressure = GetMemoryUsage(liveness_out); + Function new_func = WithAttr(std::move(GetRef(func_node)), "used_memory", + tvm::Integer(memory_pressure)); + return Let(post_let_node->var, new_func, post_let_node->body, post_let_node->span); + } + return DeviceAwareExprMutator::PostVisitLet_(pre_let_node, post_let_node); + } + + private: + /*! \brief Control flow graph representation of the main function. */ + transform::ControlFlowGraph control_flow_graph_; + /*! \brief Liveness analysis of the main function. */ + transform::LivenessAnalysis liveness_; +}; + +} // namespace aot +} // namespace backend + +namespace transform { +Pass AnnotateUsedMemory() { + runtime::TypedPackedFunc pass_func = [=](IRModule mod, + PassContext ctx) { + GlobalVar gv = mod->GetGlobalVar("main"); + Function main_func = Downcast(mod->Lookup("main")); + + // Perform liveness analysis to determine what tensors are 'live' at each functions + // callsite. + support::Arena arena; + ControlFlowGraph cfg = ControlFlowGraph::Create(&arena, main_func); + UseDefAnalysis use_def = UseDefAnalysis::Analyze(cfg); + LivenessAnalysis lva = LivenessAnalysis::Analyze(cfg, use_def); + + auto new_main_body = + backend::aot::AnnotateUsedMemoryMutator(mod, cfg, lva).VisitExpr(main_func->body); + if (!new_main_body.same_as(main_func->body)) { + Function new_main_func = WithFields(main_func, main_func->params, new_main_body); + mod->Update(gv, new_main_func); + } + return mod; + }; + return CreateModulePass(pass_func, 0, "AnnotateUsedMemory", {"ToANormalForm", "InferType"}); +} + +TVM_REGISTER_GLOBAL("relay._transform.AnnotateUsedMemory").set_body_typed(AnnotateUsedMemory); + +} // namespace transform +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 5938417128e0..0e61cf4b21c6 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -48,6 +48,7 @@ #include "../op/call/call.h" #include "../op/memory/device_copy.h" #include "../transforms/device_aware_visitors.h" +#include "./aot/annotate_used_memory.cc" #include "./name_transforms.h" #include "./te_compiler.h" #include "./utils.h" @@ -1079,6 +1080,8 @@ class AOTExecutorCodegen : public MixedModeVisitor { } mod = transform::ToANormalForm()(mod); + mod = transform::InferType()(mod); + mod = transform::AnnotateUsedMemory()(mod); IRModule lowered_mod = tec::LowerTE(mod_name, config_, [this, workspace_byte_alignment](BaseFunc func) { diff --git a/src/relay/backend/manifest_lifetimes.cc b/src/relay/backend/manifest_lifetimes.cc new file mode 100644 index 000000000000..6114cf97a8de --- /dev/null +++ b/src/relay/backend/manifest_lifetimes.cc @@ -0,0 +1,367 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/manifest_lifetimes.cc + * \brief Analysis and explicit manifestation of variable lifetimes. NOTE: the input IR should be in + * ANF and post-memory-lowering (explicit manifestation of allocations). + */ + +#include "manifest_lifetimes.h" + +#include +#include +#include +#include + +namespace tvm { +namespace relay { +namespace transform { + +using support::Arena; +using VarSet = std::unordered_set; + +ControlFlowGraph ControlFlowGraph::Create(Arena* arena, const Expr& body) { + return Creator().Create(arena, body); +} + +ControlFlowGraph ControlFlowGraph::Creator::Create(Arena* arena, const Expr& body) { + arena_ = arena; + cfg_.entry = BasicBlock::Make(arena); + VisitExpr(body, cfg_.entry); + return std::move(cfg_); +} + +void ControlFlowGraph::Creator::Succ(BasicBlockPtr from, BasicBlockPtr to) { + from->succ.push_back(to); + to->pred.push_back(from); +} + +void ControlFlowGraph::Creator::VisitExpr_(const FunctionNode* f, BasicBlockPtr parent) { + ICHECK(!in_func_) << "nested functions not supported by CFG analysis"; + in_func_ = true; + + // Unwrap the nested function and proceed normally. + if (f->HasNonzeroAttr(attr::kClosure)) { + ICHECK(f->body.as()); + return VisitExpr(Downcast(f->body)->body, parent); + } + + return VisitExpr(f->body, parent); +} + +void ControlFlowGraph::Creator::VisitExpr_(const LetNode* let_node, BasicBlockPtr parent) { + Expr expr = GetRef(let_node); + + while (const LetNode* inner_let_node = expr.as()) { + NodePtr curr_node = Node::Make(arena_, parent, expr); + + ICHECK(!cfg_.let_map.count(expr)); + cfg_.let_map[expr] = curr_node; + cfg_.reverse_post_order.push_back(curr_node); + + // The basic block ends upon reaching control flow, with successor blocks corresponding to the + // control flow branch exprs (true/false in If, and one for each clause in Match). + if (const IfNode* ite = AsIgnoringOnDevice(inner_let_node->value)) { + // Create the basic blocks for each branch and mark them as successors to the current block. + BasicBlockPtr t_block = BasicBlock::Make(arena_); + BasicBlockPtr f_block = BasicBlock::Make(arena_); + Succ(parent, t_block); + Succ(parent, f_block); + + VisitExpr(ite->true_branch, t_block); + VisitExpr(ite->false_branch, f_block); + + // All subsequent bindings (and/or the body expr) will be in a new basic block. + BasicBlockPtr next = BasicBlock::Make(arena_); + Succ(t_block, next); + Succ(f_block, next); + parent = next; + } else if (const MatchNode* match = AsIgnoringOnDevice(inner_let_node->value)) { + // Same as above but one for each pattern. + std::vector clause_blocks; + BasicBlockPtr next = BasicBlock::Make(arena_); + for (const Clause& clause : match->clauses) { + BasicBlockPtr clause_block = BasicBlock::Make(arena_); + Succ(parent, clause_block); + Succ(clause_block, next); + VisitExpr(clause->rhs, clause_block); + } + parent = next; + } + + expr = inner_let_node->body; + } + + VisitExpr(expr, parent); +} + +void ControlFlowGraph::Creator::VisitExpr_(const IfNode* if_node, BasicBlockPtr parent) { + // TODO(@altanh): is there a way of making this work? + LOG(FATAL) << "If expressions should be bound to variables."; +} + +void ControlFlowGraph::Creator::VisitExpr_(const MatchNode* match_node, BasicBlockPtr parent) { + // TODO(@altanh): same as If + LOG(FATAL) << "Match expressions should be bound to variables."; +} + +VarSet VarUseCollector::VisitExpr_(const VarNode* var_node) { return {GetRef(var_node)}; } + +VarSet VarUseCollector::VisitExpr_(const CallNode* call_node) { + VarSet use = VisitExpr(call_node->op); + for (const Expr& arg : call_node->args) { + VarSet arg_use = VisitExpr(arg); + use.insert(arg_use.begin(), arg_use.end()); + } + return use; +} + +VarSet VarUseCollector::VisitExpr_(const TupleNode* tuple_node) { + VarSet use; + for (const Expr& field : tuple_node->fields) { + VarSet field_use = VisitExpr(field); + use.insert(field_use.begin(), field_use.end()); + } + return use; +} + +VarSet VarUseCollector::VisitExpr_(const TupleGetItemNode* get_node) { + return VisitExpr(get_node->tuple); +} + +VarSet VarUseCollector::VisitExpr_(const IfNode* if_node) { return VisitExpr(if_node->cond); } + +VarSet VarUseCollector::VisitExpr_(const MatchNode* match_node) { + return VisitExpr(match_node->data); +} + +UseDefAnalysis UseDefAnalysis::Analyze(const CFG& cfg) { + UseDefAnalysis a; + + // One pass is sufficient. + for (auto it = cfg.reverse_post_order.begin(); it != cfg.reverse_post_order.end(); ++it) { + const CFG::NodePtr& node = *it; + if (const LetNode* let_node = AsIgnoringOnDevice(node->expr)) { + a.use[node] = a.use_collector.VisitExpr(let_node->value); + a.def[node] = let_node->var; + } else { + a.use[node] = a.use_collector.VisitExpr(node->expr); + a.def[node] = Var(); + } + } + + return a; +} + +bool SetEqual(const VarSet& a, const VarSet& b) { + if (a.size() != b.size()) { + return false; + } + for (auto& xa : a) { + if (!b.count(xa)) { + return false; + } + } + return true; +} + +LivenessAnalysis LivenessAnalysis::Analyze(const ControlFlowGraph& cfg, + const UseDefAnalysis& use_def) { + LivenessAnalysis a; + std::list worklist; + + // Initialize worklist to post-order traversal for quick convergence. + worklist.insert(worklist.end(), cfg.reverse_post_order.rbegin(), cfg.reverse_post_order.rend()); + + // See https://lambda.uta.edu/cse5317/notes/node40.html for an overview of the algorithm. + auto visitor = [&](const CFG::NodePtr n) { + VarSet old_in_n = a.live_in[n]; + VarSet old_out_n = a.live_out[n]; + + a.live_in[n] = use_def.use.at(n); + for (const Var& v : a.live_out[n]) { + if (!v.same_as(use_def.def.at(n))) { + a.live_in[n].insert(v); + } + } + + a.live_out[n] = VarSet(); + for (const CFG::NodePtr& s : n->GetSucc()) { + a.live_out[n].insert(a.live_in[s].begin(), a.live_in[s].end()); + } + + if (SetEqual(old_in_n, a.live_in[n]) && SetEqual(old_out_n, a.live_out[n])) { + // No need to update the worklist. + } else { + // Add predecessor nodes back to worklist (no need to add successors, since each node's + // in/out sets are not dependent on its predecessors). + for (const CFG::NodePtr& p : n->GetPred()) { + worklist.push_back(p); + } + } + }; + + while (!worklist.empty()) { + const CFG::NodePtr n = worklist.front(); + worklist.pop_front(); + visitor(n); + } + + return a; +} + +Expr KillInserter::VisitExpr_(const LetNode* let_node) { + Expr expr = GetRef(let_node); + LetList ll; + + while (const LetNode* inner_let_node = expr.as()) { + ll.Push(inner_let_node->var, VisitExpr(inner_let_node->value)); + + ICHECK(!inner_let_node->value.as()) << "aliasing should have been eliminated."; + ICHECK(cfg_->let_map.count(expr)) << "all Let exprs should be mapped in the CFG"; + + const ControlFlowGraph::NodePtr n = cfg_->let_map.at(expr); + + const VarSet& li = lva_->live_in.at(n); + const VarSet& lo = lva_->live_out.at(n); + + // Killed vars = live in - live out. + VarSet kills; + for (const Var& v : li) { + if (!lo.count(v)) { + kills.insert(v); + } + } + + for (const Var& v : kills) { + ll.Push(Call(Op::Get("memory.kill"), {v})); + } + + expr = inner_let_node->body; + } + + return ll.Get(VisitExpr(expr)); +} + +Expr AliasEliminator::VisitExpr_(const LetNode* let_node) { + Expr expr = GetRef(let_node); + LetList ll; + std::vector aliased_vars; + + while (const LetNode* inner_let_node = expr.as()) { + const Var& var = inner_let_node->var; + const Expr& val = inner_let_node->value; + bool aliased = false; + ICHECK(!alias_.count(var)); + + if (const VarNode* alias_of_n = AsIgnoringOnDevice(val)) { + alias_[var] = Downcast(VisitExpr_(alias_of_n)); + aliased = true; + } else if (AsIgnoringOnDevice(val)) { + // Copying to the same device is aliasing. + // WARNING: this must be kept in sync with the VM compiler logic in + // src/relay/backend/vm/compiler.cc, line 541, in DeviceAwareVisitExpr_(const CallNode*). + Expr unwrapped = IgnoreOnDevice(val); + DeviceCopyProps copy_props = GetDeviceCopyProps(unwrapped); + if (copy_props.body.defined()) { + if (copy_props.src_virtual_device->device_type() == + copy_props.dst_virtual_device->device_type() && + copy_props.src_virtual_device->virtual_device_id == + copy_props.dst_virtual_device->virtual_device_id) { + Expr to_copy = Downcast(unwrapped)->args[0]; + if (const VarNode* alias_of_n = to_copy.as()) { + alias_[var] = Downcast(VisitExpr_(alias_of_n)); + aliased = true; + } + } + } + } + + if (!aliased) { + ll.Push(var, VisitExpr(val)); + } else { + aliased_vars.push_back(var); + } + + expr = inner_let_node->body; + } + + Expr body = ll.Get(VisitExpr(expr)); + + // remove the aliased vars so that alias_ only tracks things in scope + for (const Var& v : aliased_vars) { + alias_.erase(v); + } + + return body; +} + +Expr AliasEliminator::VisitExpr_(const VarNode* var_node) { + Var var = GetRef(var_node); + if (alias_.count(var)) { + return alias_[var]; + } + return var; +} + +Expr AliasEliminator::VisitExpr_(const FunctionNode* func_node) { + Expr new_body = VisitExpr(func_node->body); + return WithFields(GetRef(func_node), /*opt_params=*/NullOpt, /*opt_body=*/new_body); +} + +Expr AliasEliminator::VisitExpr_(const MatchNode* match_node) { + if (const VarNode* data_var_node = AsIgnoringOnDevice(match_node->data)) { + Var data_var = Downcast(VisitExpr_(data_var_node)); + std::vector new_clauses; + for (const Clause& clause : match_node->clauses) { + const PatternVarNode* pv_node = nullptr; + if ((pv_node = clause->lhs.as())) { + alias_[pv_node->var] = data_var; + } + new_clauses.push_back(Clause(clause->lhs, VisitExpr(clause->rhs))); + if (pv_node) { + alias_.erase(pv_node->var); + } + } + return Match(data_var, new_clauses, match_node->complete, match_node->span); + } else { + return ExprMutator::VisitExpr_(match_node); + } +} + +Pass ManifestLifetimes() { + auto pass_func = [](Function f, IRModule m, PassContext pc) -> Function { + f = Downcast(AliasEliminator().Mutate(f)); + Arena arena; + ControlFlowGraph cfg = ControlFlowGraph::Create(&arena, f); + UseDefAnalysis use_def = UseDefAnalysis::Analyze(cfg); + LivenessAnalysis lva = LivenessAnalysis::Analyze(cfg, use_def); + KillInserter ki(&cfg, &lva); + Function nf = Downcast(ki.Mutate(f)); + return nf; + }; + return CreateFunctionPass(pass_func, 0, "ManifestLifetimes", {}); +} + +TVM_REGISTER_GLOBAL("relay._transform.ManifestLifetimes").set_body_typed(ManifestLifetimes); + +} // namespace transform +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/vm/manifest_lifetimes.cc b/src/relay/backend/manifest_lifetimes.h similarity index 51% rename from src/relay/backend/vm/manifest_lifetimes.cc rename to src/relay/backend/manifest_lifetimes.h index 3ba129702b52..5826fcf1ce65 100644 --- a/src/relay/backend/vm/manifest_lifetimes.cc +++ b/src/relay/backend/manifest_lifetimes.h @@ -18,17 +18,24 @@ */ /*! - * \file src/relay/backend/vm/manifest_lifetimes.cc + * \file src/relay/backend/manifest_lifetimes.h * \brief Analysis and explicit manifestation of variable lifetimes. NOTE: the input IR should be in * ANF and post-memory-lowering (explicit manifestation of allocations). */ +#ifndef TVM_RELAY_BACKEND_MANIFEST_LIFETIMES_H_ +#define TVM_RELAY_BACKEND_MANIFEST_LIFETIMES_H_ + #include -#include "../../../support/arena.h" -#include "../../op/memory/device_copy.h" -#include "../../transforms/device_aware_visitors.h" -#include "../../transforms/let_list.h" +#include +#include +#include + +#include "../../support/arena.h" +#include "../op/memory/device_copy.h" +#include "../transforms/device_aware_visitors.h" +#include "../transforms/let_list.h" namespace tvm { namespace relay { @@ -71,7 +78,7 @@ class ControlFlowGraph { // The successor basic blocks. std::vector succ; - static BasicBlockPtr Make(Arena* arena) { return arena->make(); } + static BasicBlockPtr Make(support::Arena* arena) { return arena->make(); } }; /*! @@ -154,12 +161,7 @@ class ControlFlowGraph::Creator : private ExprFunctorsucc.push_back(to); - to->pred.push_back(from); - } + void Succ(BasicBlockPtr from, BasicBlockPtr to); #define DEFAULT_CFG(OP) \ void VisitExpr_(const OP* op, BasicBlockPtr parent) final { \ @@ -187,74 +186,10 @@ class ControlFlowGraph::Creator : private ExprFunctorHasNonzeroAttr(attr::kClosure)) { - ICHECK(f->body.as()); - return VisitExpr(Downcast(f->body)->body, parent); - } - - return VisitExpr(f->body, parent); - } - - void VisitExpr_(const LetNode* let_node, BasicBlockPtr parent) final { - Expr expr = GetRef(let_node); - - while (const LetNode* inner_let_node = expr.as()) { - NodePtr curr_node = Node::Make(arena_, parent, expr); - - ICHECK(!cfg_.let_map.count(expr)); - cfg_.let_map[expr] = curr_node; - cfg_.reverse_post_order.push_back(curr_node); - - // The basic block ends upon reaching control flow, with successor blocks corresponding to the - // control flow branch exprs (true/false in If, and one for each clause in Match). - if (const IfNode* ite = AsIgnoringOnDevice(inner_let_node->value)) { - // Create the basic blocks for each branch and mark them as successors to the current block. - BasicBlockPtr t_block = BasicBlock::Make(arena_); - BasicBlockPtr f_block = BasicBlock::Make(arena_); - Succ(parent, t_block); - Succ(parent, f_block); - - VisitExpr(ite->true_branch, t_block); - VisitExpr(ite->false_branch, f_block); - - // All subsequent bindings (and/or the body expr) will be in a new basic block. - BasicBlockPtr next = BasicBlock::Make(arena_); - Succ(t_block, next); - Succ(f_block, next); - parent = next; - } else if (const MatchNode* match = AsIgnoringOnDevice(inner_let_node->value)) { - // Same as above but one for each pattern. - std::vector clause_blocks; - BasicBlockPtr next = BasicBlock::Make(arena_); - for (const Clause& clause : match->clauses) { - BasicBlockPtr clause_block = BasicBlock::Make(arena_); - Succ(parent, clause_block); - Succ(clause_block, next); - VisitExpr(clause->rhs, clause_block); - } - parent = next; - } - - expr = inner_let_node->body; - } - - VisitExpr(expr, parent); - } - - void VisitExpr_(const IfNode* if_node, BasicBlockPtr parent) { - // TODO(@altanh): is there a way of making this work? - LOG(FATAL) << "If expressions should be bound to variables."; - } - - void VisitExpr_(const MatchNode* match_node, BasicBlockPtr parent) { - // TODO(@altanh): same as If - LOG(FATAL) << "Match expressions should be bound to variables."; - } + void VisitExpr_(const FunctionNode* f, BasicBlockPtr parent) final; + void VisitExpr_(const LetNode* let_node, BasicBlockPtr parent) final; + void VisitExpr_(const IfNode* if_node, BasicBlockPtr parent); + void VisitExpr_(const MatchNode* match_node, BasicBlockPtr parent); DEFAULT_CFG(VarNode); DEFAULT_CFG(GlobalVarNode); @@ -265,10 +200,6 @@ class ControlFlowGraph::Creator : private ExprFunctor { public: - VarSet VisitExpr_(const VarNode* var_node) { return {GetRef(var_node)}; } - - VarSet VisitExpr_(const CallNode* call_node) { - VarSet use = VisitExpr(call_node->op); - for (const Expr& arg : call_node->args) { - VarSet arg_use = VisitExpr(arg); - use.insert(arg_use.begin(), arg_use.end()); - } - return use; - } - - VarSet VisitExpr_(const TupleNode* tuple_node) { - VarSet use; - for (const Expr& field : tuple_node->fields) { - VarSet field_use = VisitExpr(field); - use.insert(field_use.begin(), field_use.end()); - } - return use; - } - - VarSet VisitExpr_(const TupleGetItemNode* get_node) { return VisitExpr(get_node->tuple); } - - VarSet VisitExpr_(const IfNode* if_node) { return VisitExpr(if_node->cond); } - - VarSet VisitExpr_(const MatchNode* match_node) { return VisitExpr(match_node->data); } + VarSet VisitExpr_(const VarNode* var_node); + VarSet VisitExpr_(const CallNode* call_node); + VarSet VisitExpr_(const TupleNode* tuple_node); + VarSet VisitExpr_(const TupleGetItemNode* get_node); + VarSet VisitExpr_(const IfNode* if_node); + VarSet VisitExpr_(const MatchNode* match_node); VarSet VisitExpr_(const ConstructorNode* cons_node) { return {}; } - VarSet VisitExpr_(const GlobalVarNode* gvar_node) { return {}; } - VarSet VisitExpr_(const ConstantNode* const_node) { return {}; } - VarSet VisitExpr_(const OpNode* op_node) { return {}; } + VarSet VisitExpr_(const FunctionNode* func_node) { return {}; } }; /*! @@ -325,37 +235,11 @@ struct UseDefAnalysis { VarUseCollector use_collector; - static UseDefAnalysis Analyze(const CFG& cfg) { - UseDefAnalysis a; - - // One pass is sufficient. - for (auto it = cfg.reverse_post_order.begin(); it != cfg.reverse_post_order.end(); ++it) { - const CFG::NodePtr& node = *it; - if (const LetNode* let_node = AsIgnoringOnDevice(node->expr)) { - a.use[node] = a.use_collector.VisitExpr(let_node->value); - a.def[node] = let_node->var; - } else { - a.use[node] = a.use_collector.VisitExpr(node->expr); - a.def[node] = Var(); - } - } - - return a; - } + static UseDefAnalysis Analyze(const CFG& cfg); }; /*! \brief Returns whether \p a and \p b are the same set of vars. */ -bool SetEqual(const VarSet& a, const VarSet& b) { - if (a.size() != b.size()) { - return false; - } - for (auto& xa : a) { - if (!b.count(xa)) { - return false; - } - } - return true; -} +bool SetEqual(const VarSet& a, const VarSet& b); /*! * \brief Analysis that collects the live variables before and after each node. @@ -376,49 +260,7 @@ struct LivenessAnalysis { * \param use_def Use-def analysis of \p cfg. * \return LivenessAnalysis */ - static LivenessAnalysis Analyze(const ControlFlowGraph& cfg, const UseDefAnalysis& use_def) { - LivenessAnalysis a; - std::list worklist; - - // Initialize worklist to post-order traversal for quick convergence. - worklist.insert(worklist.end(), cfg.reverse_post_order.rbegin(), cfg.reverse_post_order.rend()); - - // See https://lambda.uta.edu/cse5317/notes/node40.html for an overview of the algorithm. - auto visitor = [&](const CFG::NodePtr n) { - VarSet old_in_n = a.live_in[n]; - VarSet old_out_n = a.live_out[n]; - - a.live_in[n] = use_def.use.at(n); - for (const Var& v : a.live_out[n]) { - if (!v.same_as(use_def.def.at(n))) { - a.live_in[n].insert(v); - } - } - - a.live_out[n] = VarSet(); - for (const CFG::NodePtr& s : n->GetSucc()) { - a.live_out[n].insert(a.live_in[s].begin(), a.live_in[s].end()); - } - - if (SetEqual(old_in_n, a.live_in[n]) && SetEqual(old_out_n, a.live_out[n])) { - // No need to update the worklist. - } else { - // Add predecessor nodes back to worklist (no need to add successors, since each node's - // in/out sets are not dependent on its predecessors). - for (const CFG::NodePtr& p : n->GetPred()) { - worklist.push_back(p); - } - } - }; - - while (!worklist.empty()) { - const CFG::NodePtr n = worklist.front(); - worklist.pop_front(); - visitor(n); - } - - return a; - } + static LivenessAnalysis Analyze(const ControlFlowGraph& cfg, const UseDefAnalysis& use_def); }; /*! @@ -481,38 +323,7 @@ class KillInserter : public ExprMutator { // // However, these limitations are unlikely to cause large leaks in practice. - Expr VisitExpr_(const LetNode* let_node) override { - Expr expr = GetRef(let_node); - LetList ll; - - while (const LetNode* inner_let_node = expr.as()) { - ll.Push(inner_let_node->var, VisitExpr(inner_let_node->value)); - - ICHECK(!inner_let_node->value.as()) << "aliasing should have been eliminated."; - ICHECK(cfg_->let_map.count(expr)) << "all Let exprs should be mapped in the CFG"; - - const ControlFlowGraph::NodePtr n = cfg_->let_map.at(expr); - - const VarSet& li = lva_->live_in.at(n); - const VarSet& lo = lva_->live_out.at(n); - - // Killed vars = live in - live out. - VarSet kills; - for (const Var& v : li) { - if (!lo.count(v)) { - kills.insert(v); - } - } - - for (const Var& v : kills) { - ll.Push(Call(Op::Get("memory.kill"), {v})); - } - - expr = inner_let_node->body; - } - - return ll.Get(VisitExpr(expr)); - } + Expr VisitExpr_(const LetNode* let_node); private: const ControlFlowGraph* cfg_; @@ -529,93 +340,13 @@ class AliasEliminator : public MixedModeMutator { public: using MixedModeMutator::VisitExpr_; - Expr VisitExpr_(const LetNode* let_node) override { - Expr expr = GetRef(let_node); - LetList ll; - std::vector aliased_vars; - - while (const LetNode* inner_let_node = expr.as()) { - const Var& var = inner_let_node->var; - const Expr& val = inner_let_node->value; - bool aliased = false; - ICHECK(!alias_.count(var)); - - if (const VarNode* alias_of_n = AsIgnoringOnDevice(val)) { - alias_[var] = Downcast(VisitExpr_(alias_of_n)); - aliased = true; - } else if (AsIgnoringOnDevice(val)) { - // Copying to the same device is aliasing. - // WARNING: this must be kept in sync with the VM compiler logic in - // src/relay/backend/vm/compiler.cc, line 541, in DeviceAwareVisitExpr_(const CallNode*). - Expr unwrapped = IgnoreOnDevice(val); - DeviceCopyProps copy_props = GetDeviceCopyProps(unwrapped); - if (copy_props.body.defined()) { - if (copy_props.src_virtual_device->device_type() == - copy_props.dst_virtual_device->device_type() && - copy_props.src_virtual_device->virtual_device_id == - copy_props.dst_virtual_device->virtual_device_id) { - Expr to_copy = Downcast(unwrapped)->args[0]; - if (const VarNode* alias_of_n = to_copy.as()) { - alias_[var] = Downcast(VisitExpr_(alias_of_n)); - aliased = true; - } - } - } - } - - if (!aliased) { - ll.Push(var, VisitExpr(val)); - } else { - aliased_vars.push_back(var); - } - - expr = inner_let_node->body; - } - - Expr body = ll.Get(VisitExpr(expr)); - - // remove the aliased vars so that alias_ only tracks things in scope - for (const Var& v : aliased_vars) { - alias_.erase(v); - } - - return body; - } - - Expr VisitExpr_(const VarNode* var_node) override { - Var var = GetRef(var_node); - if (alias_.count(var)) { - return alias_[var]; - } - return var; - } - - Expr VisitExpr_(const FunctionNode* func_node) override { - Expr new_body = VisitExpr(func_node->body); - return WithFields(GetRef(func_node), /*opt_params=*/NullOpt, /*opt_body=*/new_body); - } + Expr VisitExpr_(const LetNode* let_node) override; + Expr VisitExpr_(const VarNode* var_node) override; + Expr VisitExpr_(const FunctionNode* func_node) override; // The only register-level aliasing that occurs in Match expressions is when // the deconstructed expression is a Var, and the matched pattern is also a Var. - Expr VisitExpr_(const MatchNode* match_node) override { - if (const VarNode* data_var_node = AsIgnoringOnDevice(match_node->data)) { - Var data_var = Downcast(VisitExpr_(data_var_node)); - std::vector new_clauses; - for (const Clause& clause : match_node->clauses) { - const PatternVarNode* pv_node = nullptr; - if ((pv_node = clause->lhs.as())) { - alias_[pv_node->var] = data_var; - } - new_clauses.push_back(Clause(clause->lhs, VisitExpr(clause->rhs))); - if (pv_node) { - alias_.erase(pv_node->var); - } - } - return Match(data_var, new_clauses, match_node->complete, match_node->span); - } else { - return ExprMutator::VisitExpr_(match_node); - } - } + Expr VisitExpr_(const MatchNode* match_node) override; private: /*! @@ -625,22 +356,10 @@ class AliasEliminator : public MixedModeMutator { std::unordered_map alias_; }; -Pass ManifestLifetimes() { - auto pass_func = [](Function f, IRModule m, PassContext pc) -> Function { - f = Downcast(AliasEliminator().Mutate(f)); - Arena arena; - ControlFlowGraph cfg = ControlFlowGraph::Create(&arena, f); - UseDefAnalysis use_def = UseDefAnalysis::Analyze(cfg); - LivenessAnalysis lva = LivenessAnalysis::Analyze(cfg, use_def); - KillInserter ki(&cfg, &lva); - Function nf = Downcast(ki.Mutate(f)); - return nf; - }; - return CreateFunctionPass(pass_func, 0, "ManifestLifetimes", {}); -} - -TVM_REGISTER_GLOBAL("relay._transform.ManifestLifetimes").set_body_typed(ManifestLifetimes); +Pass ManifestLifetimes(); } // namespace transform } // namespace relay } // namespace tvm + +#endif // TVM_RELAY_BACKEND_MANIFEST_LIFETIMES_H_ diff --git a/tests/python/relay/aot/test_used_memory_annotator.py b/tests/python/relay/aot/test_used_memory_annotator.py new file mode 100644 index 000000000000..c882984a3835 --- /dev/null +++ b/tests/python/relay/aot/test_used_memory_annotator.py @@ -0,0 +1,194 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=invalid-name + +""" +Testing for the pass that annotates used memory for each primitive +Relay function. +""" + +import tvm +from tvm import relay +from tvm.relay.expr_functor import ExprVisitor + + +def AnnotateUsedMemory(): + return relay.transform._ffi_api.AnnotateUsedMemory() + + +class CheckUsedMemoryAnnotation(ExprVisitor): + """ + Check that the annotations on each function in the graph match + what is expected. + """ + + def __init__(self, expected_annotations): + self.expected_annotations = expected_annotations + super().__init__() + + def visit_function(self, fn): + if "Primitive" in fn.attrs: + assert ( + "used_memory" in fn.attrs + ), "Primitive function does not have used_memory annotation." + + assert len(self.expected_annotations) > 0, "Not all expected annotations were compared" + + expected_mem = self.expected_annotations.pop(0) + actual_mem = fn.attrs["used_memory"] + assert expected_mem == actual_mem, ( + f"Expected used memory annotation {expected_mem} " + f"did not match actual annotation {actual_mem}" + ) + super().visit_function(fn) + + +def _check_used_memory_annotations(mod, expected_annotations): + mod = relay.transform.InferType()(mod) + mod = relay.transform.ToANormalForm()(mod) + mod = relay.transform.InferType()(mod) + mod = AnnotateUsedMemory()(mod) + + CheckUsedMemoryAnnotation(expected_annotations).visit(mod["main"].body) + + +def _create_primitive_function(expr): + func = relay.Function(relay.analysis.free_vars(expr), expr) + func = func.with_attr("Primitive", 1) + return func + + +def test_simple(): + """ + Test simple graph with one primitive function. + """ + + def get_inner_func(): + x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8") + x = relay.nn.max_pool2d(x) + x = _create_primitive_function(x) + return x + + ifm = relay.var("input", shape=(1, 2, 2, 4), dtype="int8") + call = relay.Call(get_inner_func(), [ifm]) + mod = tvm.IRModule.from_expr(call) + + expected_annotations = [2 * (1 * 2 * 2 * 4)] + _check_used_memory_annotations(mod, expected_annotations) + + +def test_multiple_functions(): + """ + Test a graph with multiple primitive functions. + """ + + def get_inner_func(ifm_shape): + x = relay.var("x", shape=ifm_shape, dtype="int8") + x = relay.nn.max_pool2d(x, pool_size=(2, 2), layout="NHWC") + x = _create_primitive_function(x) + return x + + ifm = relay.var("input", shape=(1, 8, 8, 2), dtype="int8") + x = get_inner_func((1, 8, 8, 2)) + x = relay.Call(x, [ifm]) + y = get_inner_func((1, 7, 7, 2)) + y = relay.Call(y, [x]) + z = get_inner_func((1, 6, 6, 2)) + z = relay.Call(z, [y]) + mod = tvm.IRModule.from_expr(z) + + expected_annotations = [ + (1 * 8 * 8 * 2) + (1 * 7 * 7 * 2), + (1 * 7 * 7 * 2) + (1 * 6 * 6 * 2), + (1 * 6 * 6 * 2) + (1 * 5 * 5 * 2), + ] + _check_used_memory_annotations(mod, expected_annotations) + + +def test_mixed_data_types(): + """ + Test a graph with a primitive function that has mixed datatypes. + """ + + def get_inner_func(): + x = relay.var("x", shape=(1, 2, 2, 2), dtype="int16") + x = relay.cast(x, dtype="uint32") + x = _create_primitive_function(x) + return x + + ifm = relay.var("input", shape=(1, 2, 2, 2), dtype="int16") + x = get_inner_func() + x = relay.Call(x, [ifm]) + mod = tvm.IRModule.from_expr(x) + + expected_annotations = [ + (1 * 2 * 2 * 2) * 2 + (1 * 2 * 2 * 2) * 4, + ] + _check_used_memory_annotations(mod, expected_annotations) + + +def test_parallel_function_call(): + """ + Test a graph when the results of two functions are concatenated + into a single result. The second function will also have the result + of the first function alive so will be annotated with a larger + "used memory" value. + """ + + def get_inner_func(): + x = relay.var("x", shape=(1, 4, 5, 6), dtype="int8") + x = relay.reshape(x, newshape=(1, 4, 30)) + x = _create_primitive_function(x) + return x + + ifm = relay.var("input", shape=(1, 4, 5, 6), dtype="int8") + x = relay.Call(get_inner_func(), [ifm]) + y = relay.Call(get_inner_func(), [ifm]) + z = relay.concatenate([x, y], axis=0) + mod = tvm.IRModule.from_expr(z) + + expected_annotations = [ + (1 * 4 * 5 * 6) + (1 * 4 * 30), + # the output tensor from the previous function is also alive + (1 * 4 * 5 * 6) + (1 * 4 * 30) + (1 * 4 * 30), + ] + _check_used_memory_annotations(mod, expected_annotations) + + +def test_composite_inner_function(): + """ + Tests the typical BYOC use case where a primitive function + contains a composite function. + """ + + def get_inner_func(): + x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8") + x = relay.nn.max_pool2d(x, pool_size=(2, 2)) + x = relay.Function(relay.analysis.free_vars(x), x) + x = x.with_attr("Composite", "my_composite_func") + + y = relay.var("y", shape=(1, 2, 2, 4), dtype="int8") + z = relay.Call(x, [y]) + z = _create_primitive_function(z) + return x + + ifm = relay.var("input", shape=(1, 2, 2, 4), dtype="int8") + x = relay.Call(get_inner_func(), [ifm]) + mod = tvm.IRModule.from_expr(x) + + expected_annotations = [(1 * 2 * 2 * 4) + (1 * 1 * 1 * 4)] + _check_used_memory_annotations(mod, expected_annotations) From ab605fba32aa0e38d85618964648c61fd80f7e1c Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Mon, 9 May 2022 09:31:56 +0000 Subject: [PATCH 2/5] small fix to file description Change-Id: I0e460f6cf43f9b12ffa5fc66fcb68e55304daeb2 --- src/relay/backend/aot/annotate_used_memory.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/relay/backend/aot/annotate_used_memory.cc b/src/relay/backend/aot/annotate_used_memory.cc index 25d1624635e0..0eeca8f3a86a 100644 --- a/src/relay/backend/aot/annotate_used_memory.cc +++ b/src/relay/backend/aot/annotate_used_memory.cc @@ -18,8 +18,8 @@ */ /*! - * \file src/relay/backend/aot/memory_pressure.cc - * \brief Analyses the memory pressure at external function callsites. + * \file src/relay/backend/aot/annotate_used_memory.cc + * \brief Analyzes the memory pressure at the callsite of primitive functions. */ #include From aed1281aae5954e124494b42eaab7e50490c5c57 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Thu, 19 May 2022 18:51:44 +0000 Subject: [PATCH 3/5] Various improvements addressing comments In addition, a new "io_used_memory" annotation is added to the main function which refers to the total size of the IO tensors in the provided module, enabling these to be discounted from memory pressure calculations where necessary. Change-Id: Iafe9c85d7fc69c77a2115ed4efe7645160387c86 --- include/tvm/relay/transform.h | 9 + src/relay/backend/annotate_used_memory.cc | 222 +++++++++++++++ src/relay/backend/aot/annotate_used_memory.cc | 156 ----------- src/relay/backend/aot_executor_codegen.cc | 1 - ...fest_lifetimes.cc => liveness_analysis.cc} | 143 +--------- ...nifest_lifetimes.h => liveness_analysis.h} | 107 +------ src/relay/backend/vm/manifest_lifetimes.cc | 260 ++++++++++++++++++ .../relay/aot/test_used_memory_annotator.py | 130 ++++++++- 8 files changed, 620 insertions(+), 408 deletions(-) create mode 100644 src/relay/backend/annotate_used_memory.cc delete mode 100644 src/relay/backend/aot/annotate_used_memory.cc rename src/relay/backend/{manifest_lifetimes.cc => liveness_analysis.cc} (60%) rename src/relay/backend/{manifest_lifetimes.h => liveness_analysis.h} (68%) create mode 100644 src/relay/backend/vm/manifest_lifetimes.cc diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index b592265c74cd..d1bdff8f7a31 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -556,6 +556,15 @@ TVM_DLL Pass PlanDevices(CompilationConfig config); */ TVM_DLL Pass FlattenAtrousConv(); +/*! + * \brief Annotates the memory usage of each primitive function by analyzing the liveness + * of the input/output tensors at each function callsite and calculating the total amount of + * memory these tensors require. This is added as a "used_memory" annotation to the function + * in question. In addition, the containing function is annotated with an "io_used_memory" + * annotation which refers to the total memory required for the IO tensors. + */ +TVM_DLL Pass AnnotateUsedMemory(); + } // namespace transform /*! diff --git a/src/relay/backend/annotate_used_memory.cc b/src/relay/backend/annotate_used_memory.cc new file mode 100644 index 000000000000..5cc0a9a7d0ff --- /dev/null +++ b/src/relay/backend/annotate_used_memory.cc @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/annotate_used_memory.cc + * \brief Analyzes the used memory at the callsite of primitive functions. + */ + +#include +#include +#include + +#include +#include + +#include "../transforms/device_aware_visitors.h" +#include "./liveness_analysis.h" +#include "./utils.h" + +namespace tvm { +namespace relay { +namespace backend { + +/*! + * \brief Annotates the memory usage of each primitive function by analyzing the liveness + * of the input/output tensors at each function callsite and calculating the total amount of + * memory these tensors require. This is added as a "used_memory" annotation to the function + * in question. In addition, the containing function is annotated with an "io_used_memory" + * annotation which refers to the total memory required for the IO tensors. + * + * A simple example: + * + * Before: + * def @main(%input: Tensor[(1, 2, 2, 4), int8]) -> Tensor[(1, 2, 2, 4), int8] { + * let %x_0 = fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1) -> Tensor[(1, 2, 2, 4), int8] { + * nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0]) + * }; + * let %x_1 = %x_0(%input); + * %x_1 + * } + * + * After: + * def @main(%input: Tensor[(1, 2, 2, 4), int8], io_used_memory=32) -> Tensor[(1, 2, 2, 4), int8] { + * let %x_0: fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1, used_memory=32) -> Tensor[(1, 2, 2, + * 4), int8] { nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0]) + * }; + * let %x_1: Tensor[(1, 2, 2, 4), int8] = %x_0(%input); + * %x_1 + * } + * + * Note that in the simple example above io_used_memory and used_memory are the same since there + * is only one primitive function. + */ +class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator { + public: + AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg, + const transform::LivenessAnalysis& lva) + : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {} + + /*! + * \brief Mutates the input function. In addition, an "io_used_memory" annotation is + * added to the input function which refers to the total size required for the IO + * tensors. + */ + Function operator()(const Function& func) { + uint64_t io_used_memory = 0; + + // Inputs + for (const Var& param : func->params) { + Type type = param->checked_type(); + ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory."; + io_used_memory += CalculateRelayExprSizeBytes(type); + } + + // Outputs + Type type = func->body->checked_type(); + ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory."; + io_used_memory += CalculateRelayExprSizeBytes(type); + + Expr new_func_body = VisitExpr(func->body); + Function new_func = WithFields(func, func->params, new_func_body); + return WithAttr(std::move(new_func), "io_used_memory", + tvm::IntImm(tvm::DataType::UInt(64), io_used_memory)); + } + + /*! + * \brief Establish which let bindings have primitive function values. + */ + std::pair PreVisitLetBinding_(const Var& var, const Expr& value) { + if (const auto* func_node = value.as()) { + ICHECK(func_node->attrs.HasNonzeroAttr(attr::kPrimitive)) + << "Expect top-level functions to be primitive."; + let_bound_prim_func_.insert(var); + } + return DeviceAwareExprMutator::PreVisitLetBinding_(var, value); + } + + /*! + * \brief Visit let nodes and perform one of two actions depending on their value: + * + * 1. CallNode - Calculate "used_memory" annotation value at the callsite of + * primitive functions. + * + * 2. FunctionNode - Annotate functions with "used_memory" annotation based on the + * previous analysis at the callsite. + * + */ + Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) override { + Var let_var = post_let_node->var; + Expr let_value = IgnoreOnDevice(post_let_node->value); + + if (let_value->IsInstance()) { + Call callsite = Downcast(let_value); + if (CheckPrimitiveFunctionCall(callsite)) { + Var call_op = Downcast(callsite->op); + + // Find all the vars that are live at the callsite. This is done by merging the + // in and out varset's and then removing the var that references the primitive + // function itself since we don't want this included in the calculation. + const transform::ControlFlowGraph::NodePtr cfg_node = + control_flow_graph_.let_map.at(GetRef(pre_let_node)); + transform::VarSet live_tensors = liveness_.live_in.at(cfg_node); + const transform::VarSet& live_out = liveness_.live_out.at(cfg_node); + live_tensors.insert(live_out.begin(), live_out.end()); + live_tensors.erase(call_op); + + // Calculate size of live tensors and store to allow annotation when the function + // gets visited. + uint64_t used_memory = 0; + for (const auto& var : live_tensors) { + Type type = var->checked_type(); + ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory."; + used_memory += CalculateRelayExprSizeBytes(type); + } + used_memory_annotations_[call_op] = used_memory; + } + } else if (let_value->IsInstance()) { + Function func = Downcast(let_value); + ICHECK(used_memory_annotations_.find(let_var) != used_memory_annotations_.end()) + << "Could not find used_memory value for primitive function bound at " + << let_var->name_hint(); + uint64_t used_memory = used_memory_annotations_[let_var]; + used_memory_annotations_.erase(let_var); + Function new_func = WithAttr(std::move(func), "used_memory", + tvm::IntImm(tvm::DataType::UInt(64), used_memory)); + return Let(let_var, new_func, post_let_node->body, post_let_node->span); + } + + return DeviceAwareExprMutator::PostVisitLet_(pre_let_node, post_let_node); + } + + private: + /*! + * \brief Check if a call is a primitive function callsite. + */ + bool CheckPrimitiveFunctionCall(const Call& callsite) { + if (const auto* var_node = callsite->op.as()) { + Var var = GetRef(var_node); + if (let_bound_prim_func_.find(var) != let_bound_prim_func_.end()) { + return true; + } + } + return false; + } + + /*! \brief Control flow graph representation of the main function. */ + transform::ControlFlowGraph control_flow_graph_; + /*! \brief Liveness analysis of the main function. */ + transform::LivenessAnalysis liveness_; + /*! \brief Var's that reference primitive functions. */ + std::unordered_set let_bound_prim_func_; + /*! \brief Stores the calculated used_memory values so they can be annotated on the relevant + * function. */ + std::unordered_map used_memory_annotations_; +}; + +} // namespace backend + +namespace transform { + +Pass AnnotateUsedMemory() { + runtime::TypedPackedFunc pass_func = [=](IRModule mod, + PassContext ctx) { + GlobalVar gv = mod->GetGlobalVar("main"); + Function main_func = Downcast(mod->Lookup("main")); + + // Perform liveness analysis to determine what tensors are 'live' at each functions callsite. + support::Arena arena; + ControlFlowGraph cfg = ControlFlowGraph::Create(&arena, main_func); + UseDefAnalysis use_def = UseDefAnalysis::Analyze(cfg); + LivenessAnalysis lva = LivenessAnalysis::Analyze(cfg, use_def); + + auto new_main_func = backend::AnnotateUsedMemoryMutator(mod, cfg, lva)(main_func); + if (!new_main_func.same_as(main_func)) { + mod->Update(gv, new_main_func); + } + return mod; + }; + return CreateModulePass(pass_func, 0, "AnnotateUsedMemory", {"ToANormalForm", "InferType"}); +} + +TVM_REGISTER_GLOBAL("relay._transform.AnnotateUsedMemory").set_body_typed(AnnotateUsedMemory); + +} // namespace transform +} // namespace relay +} // namespace tvm diff --git a/src/relay/backend/aot/annotate_used_memory.cc b/src/relay/backend/aot/annotate_used_memory.cc deleted file mode 100644 index 0eeca8f3a86a..000000000000 --- a/src/relay/backend/aot/annotate_used_memory.cc +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ - -/*! - * \file src/relay/backend/aot/annotate_used_memory.cc - * \brief Analyzes the memory pressure at the callsite of primitive functions. - */ - -#include -#include - -#include "../../transforms/device_aware_visitors.h" -#include "../manifest_lifetimes.h" - -namespace tvm { -namespace relay { -namespace backend { -namespace aot { - -/*! - * \brief Annotates the memory usage of each primitive function by analysing the liveness - * of the input/output tensors at the function callsite and calculating the total amount of - * memory these tensors require. - */ -class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator { - public: - AnnotateUsedMemoryMutator(const IRModule& module, const transform::ControlFlowGraph& cfg, - const transform::LivenessAnalysis& lva) - : DeviceAwareExprMutator(module), control_flow_graph_(cfg), liveness_(lva) {} - - /*! - * \brief Get the memory required for a primitive Relay function by calculating the total - * bytes of the live tensors at the callsite of the function. - * - * \param live_tensors The tensors that are live when the function is called. - * \return int The total number of bytes a function requires. - */ - int GetMemoryUsage(const transform::VarSet& live_tensors) { - Array types_stack = {}; - int memory_usage = 0; - - for (const Var& var : live_tensors) { - Type var_type = var->checked_type(); - ICHECK(var_type.defined()) << "InferTypes pass should be run before AnnotateUsedMemory pass."; - types_stack.push_back(var_type); - } - - while (!types_stack.empty()) { - Type current_type = types_stack.back(); - types_stack.pop_back(); - - if (const auto* tt_node = current_type.as()) { - for (const Type& type : tt_node->fields) { - types_stack.push_back(type); - } - continue; - } else if (const auto* ft_node = current_type.as()) { - types_stack.push_back(ft_node->ret_type); - continue; - } - - const auto* tt_node = current_type.as(); - ICHECK(tt_node) << "Expected TensorTypeNode but was " << current_type->GetTypeKey(); - int total_tensor_bytes = GetTensorBytes(tt_node); - memory_usage += total_tensor_bytes; - } - return memory_usage; - } - - /*! - * \brief Get the number of bytes a tensor requires. - * - * \param tensor_type_node The checked type of the tensor. - * \return int The number of bytes required. - */ - int GetTensorBytes(const TensorTypeNode* tensor_type_node) { - PrimExpr size = tensor_type_node->Size(); - const auto* size_int_imm = size.as(); - ICHECK(size_int_imm) << "Expected tensor size to be an IntImmNode but was " - << size->GetTypeKey(); - - int total_size = size_int_imm->value; - int dtype_bytes = tensor_type_node->dtype.bytes(); - return total_size * dtype_bytes; - } - - Expr PostVisitLet_(const LetNode* pre_let_node, const LetNode* post_let_node) override { - if (const auto* func_node = pre_let_node->value.as()) { - const auto let_bound_values = control_flow_graph_.let_map; - const transform::ControlFlowGraph::NodePtr cfg_node = - let_bound_values.at(GetRef(pre_let_node)); - const transform::VarSet& liveness_out = liveness_.live_out.at(cfg_node); - int memory_pressure = GetMemoryUsage(liveness_out); - Function new_func = WithAttr(std::move(GetRef(func_node)), "used_memory", - tvm::Integer(memory_pressure)); - return Let(post_let_node->var, new_func, post_let_node->body, post_let_node->span); - } - return DeviceAwareExprMutator::PostVisitLet_(pre_let_node, post_let_node); - } - - private: - /*! \brief Control flow graph representation of the main function. */ - transform::ControlFlowGraph control_flow_graph_; - /*! \brief Liveness analysis of the main function. */ - transform::LivenessAnalysis liveness_; -}; - -} // namespace aot -} // namespace backend - -namespace transform { -Pass AnnotateUsedMemory() { - runtime::TypedPackedFunc pass_func = [=](IRModule mod, - PassContext ctx) { - GlobalVar gv = mod->GetGlobalVar("main"); - Function main_func = Downcast(mod->Lookup("main")); - - // Perform liveness analysis to determine what tensors are 'live' at each functions - // callsite. - support::Arena arena; - ControlFlowGraph cfg = ControlFlowGraph::Create(&arena, main_func); - UseDefAnalysis use_def = UseDefAnalysis::Analyze(cfg); - LivenessAnalysis lva = LivenessAnalysis::Analyze(cfg, use_def); - - auto new_main_body = - backend::aot::AnnotateUsedMemoryMutator(mod, cfg, lva).VisitExpr(main_func->body); - if (!new_main_body.same_as(main_func->body)) { - Function new_main_func = WithFields(main_func, main_func->params, new_main_body); - mod->Update(gv, new_main_func); - } - return mod; - }; - return CreateModulePass(pass_func, 0, "AnnotateUsedMemory", {"ToANormalForm", "InferType"}); -} - -TVM_REGISTER_GLOBAL("relay._transform.AnnotateUsedMemory").set_body_typed(AnnotateUsedMemory); - -} // namespace transform -} // namespace relay -} // namespace tvm diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 0e61cf4b21c6..5020e79714b2 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -48,7 +48,6 @@ #include "../op/call/call.h" #include "../op/memory/device_copy.h" #include "../transforms/device_aware_visitors.h" -#include "./aot/annotate_used_memory.cc" #include "./name_transforms.h" #include "./te_compiler.h" #include "./utils.h" diff --git a/src/relay/backend/manifest_lifetimes.cc b/src/relay/backend/liveness_analysis.cc similarity index 60% rename from src/relay/backend/manifest_lifetimes.cc rename to src/relay/backend/liveness_analysis.cc index 6114cf97a8de..52db9e6a4c23 100644 --- a/src/relay/backend/manifest_lifetimes.cc +++ b/src/relay/backend/liveness_analysis.cc @@ -18,12 +18,12 @@ */ /*! - * \file src/relay/backend/manifest_lifetimes.cc - * \brief Analysis and explicit manifestation of variable lifetimes. NOTE: the input IR should be in - * ANF and post-memory-lowering (explicit manifestation of allocations). + * \file src/relay/backend/liveness_analysis.cc + * \brief Analysis that collects the live variables before and after each node. + * NOTE: the input IR should be in ANF. */ -#include "manifest_lifetimes.h" +#include "./liveness_analysis.h" #include #include @@ -227,141 +227,6 @@ LivenessAnalysis LivenessAnalysis::Analyze(const ControlFlowGraph& cfg, return a; } -Expr KillInserter::VisitExpr_(const LetNode* let_node) { - Expr expr = GetRef(let_node); - LetList ll; - - while (const LetNode* inner_let_node = expr.as()) { - ll.Push(inner_let_node->var, VisitExpr(inner_let_node->value)); - - ICHECK(!inner_let_node->value.as()) << "aliasing should have been eliminated."; - ICHECK(cfg_->let_map.count(expr)) << "all Let exprs should be mapped in the CFG"; - - const ControlFlowGraph::NodePtr n = cfg_->let_map.at(expr); - - const VarSet& li = lva_->live_in.at(n); - const VarSet& lo = lva_->live_out.at(n); - - // Killed vars = live in - live out. - VarSet kills; - for (const Var& v : li) { - if (!lo.count(v)) { - kills.insert(v); - } - } - - for (const Var& v : kills) { - ll.Push(Call(Op::Get("memory.kill"), {v})); - } - - expr = inner_let_node->body; - } - - return ll.Get(VisitExpr(expr)); -} - -Expr AliasEliminator::VisitExpr_(const LetNode* let_node) { - Expr expr = GetRef(let_node); - LetList ll; - std::vector aliased_vars; - - while (const LetNode* inner_let_node = expr.as()) { - const Var& var = inner_let_node->var; - const Expr& val = inner_let_node->value; - bool aliased = false; - ICHECK(!alias_.count(var)); - - if (const VarNode* alias_of_n = AsIgnoringOnDevice(val)) { - alias_[var] = Downcast(VisitExpr_(alias_of_n)); - aliased = true; - } else if (AsIgnoringOnDevice(val)) { - // Copying to the same device is aliasing. - // WARNING: this must be kept in sync with the VM compiler logic in - // src/relay/backend/vm/compiler.cc, line 541, in DeviceAwareVisitExpr_(const CallNode*). - Expr unwrapped = IgnoreOnDevice(val); - DeviceCopyProps copy_props = GetDeviceCopyProps(unwrapped); - if (copy_props.body.defined()) { - if (copy_props.src_virtual_device->device_type() == - copy_props.dst_virtual_device->device_type() && - copy_props.src_virtual_device->virtual_device_id == - copy_props.dst_virtual_device->virtual_device_id) { - Expr to_copy = Downcast(unwrapped)->args[0]; - if (const VarNode* alias_of_n = to_copy.as()) { - alias_[var] = Downcast(VisitExpr_(alias_of_n)); - aliased = true; - } - } - } - } - - if (!aliased) { - ll.Push(var, VisitExpr(val)); - } else { - aliased_vars.push_back(var); - } - - expr = inner_let_node->body; - } - - Expr body = ll.Get(VisitExpr(expr)); - - // remove the aliased vars so that alias_ only tracks things in scope - for (const Var& v : aliased_vars) { - alias_.erase(v); - } - - return body; -} - -Expr AliasEliminator::VisitExpr_(const VarNode* var_node) { - Var var = GetRef(var_node); - if (alias_.count(var)) { - return alias_[var]; - } - return var; -} - -Expr AliasEliminator::VisitExpr_(const FunctionNode* func_node) { - Expr new_body = VisitExpr(func_node->body); - return WithFields(GetRef(func_node), /*opt_params=*/NullOpt, /*opt_body=*/new_body); -} - -Expr AliasEliminator::VisitExpr_(const MatchNode* match_node) { - if (const VarNode* data_var_node = AsIgnoringOnDevice(match_node->data)) { - Var data_var = Downcast(VisitExpr_(data_var_node)); - std::vector new_clauses; - for (const Clause& clause : match_node->clauses) { - const PatternVarNode* pv_node = nullptr; - if ((pv_node = clause->lhs.as())) { - alias_[pv_node->var] = data_var; - } - new_clauses.push_back(Clause(clause->lhs, VisitExpr(clause->rhs))); - if (pv_node) { - alias_.erase(pv_node->var); - } - } - return Match(data_var, new_clauses, match_node->complete, match_node->span); - } else { - return ExprMutator::VisitExpr_(match_node); - } -} - -Pass ManifestLifetimes() { - auto pass_func = [](Function f, IRModule m, PassContext pc) -> Function { - f = Downcast(AliasEliminator().Mutate(f)); - Arena arena; - ControlFlowGraph cfg = ControlFlowGraph::Create(&arena, f); - UseDefAnalysis use_def = UseDefAnalysis::Analyze(cfg); - LivenessAnalysis lva = LivenessAnalysis::Analyze(cfg, use_def); - KillInserter ki(&cfg, &lva); - Function nf = Downcast(ki.Mutate(f)); - return nf; - }; - return CreateFunctionPass(pass_func, 0, "ManifestLifetimes", {}); -} - -TVM_REGISTER_GLOBAL("relay._transform.ManifestLifetimes").set_body_typed(ManifestLifetimes); - } // namespace transform } // namespace relay } // namespace tvm diff --git a/src/relay/backend/manifest_lifetimes.h b/src/relay/backend/liveness_analysis.h similarity index 68% rename from src/relay/backend/manifest_lifetimes.h rename to src/relay/backend/liveness_analysis.h index 5826fcf1ce65..4e9514056b86 100644 --- a/src/relay/backend/manifest_lifetimes.h +++ b/src/relay/backend/liveness_analysis.h @@ -18,13 +18,13 @@ */ /*! - * \file src/relay/backend/manifest_lifetimes.h - * \brief Analysis and explicit manifestation of variable lifetimes. NOTE: the input IR should be in - * ANF and post-memory-lowering (explicit manifestation of allocations). + * \file src/relay/backend/liveness_analysis.h + * \brief Analysis that collects the live variables before and after each node. + * NOTE: the input IR should be in ANF. */ -#ifndef TVM_RELAY_BACKEND_MANIFEST_LIFETIMES_H_ -#define TVM_RELAY_BACKEND_MANIFEST_LIFETIMES_H_ +#ifndef TVM_RELAY_BACKEND_LIVENESS_ANALYSIS_H_ +#define TVM_RELAY_BACKEND_LIVENESS_ANALYSIS_H_ #include @@ -263,103 +263,8 @@ struct LivenessAnalysis { static LivenessAnalysis Analyze(const ControlFlowGraph& cfg, const UseDefAnalysis& use_def); }; -/*! - * \brief Helper class to insert kills using liveness information. - */ -class KillInserter : public ExprMutator { - public: - KillInserter(const ControlFlowGraph* cfg, const LivenessAnalysis* lva) : cfg_(cfg), lva_(lva) {} - - // Limitations - // ----------- - // (1) For simplicity, we only insert kills when visiting Let bindings, and always emit the kill - // as a single subsequent binding. This is slightly inaccurate; for example, if the condition of - // an If is dead after the test, we can immediately kill the condition in each branch: - // let %x = if (%dead_cond) { - // let %_0 = memory.kill(%dead_cond); - // ... - // } else { - // let %_1 = memory.kill(%dead_cond); - // ... - // } - // as opposed to: - // let %x = if (%dead_cond) ... - // let %_0 = memory.kill(%dead_cond); - // - // (2) Killed variables are calculated as live in - live out, which misses variables that are - // actually dead but not in a live-in set. Example: - // @f(%x: int, %y: int, %c: bool) { - // let %w = if (%c) { - // let %z = %y + %y; - // %z - // } else { - // %y - // }; - // %w - // } - // After inserting kills: - // @f(%x: int, %y: int, %c: bool) { - // /* %x is always dead, so never in any live in or live out set */ - // let %w = if (%c) { - // let %z = %y + %y; - // let %_0 = memory.kill(%y); - // %z - // } else { - // %y - // /* %y is dead at this point */ - // }; - // let %_1 = memory.kill(%c); - // /* no kill for %y since it's not in the live-in of %w AND %w isn't a let binding */ - // %w - // } - // - // (3) When the result expr of an If branch is a variable, and this expr is the last use of the - // var, we cannot "kill" the var since it is being returned. The VM compiler also emits a Move - // instruction to merge the branch results, which creates another ObjectRef to the Object held - // by the var. The var is also not in the subsequent live-in (since it is indeed dead by this - // point), so it won't be killed. An example can be seen in the previous code block for (2), where - // %y is not killed if the else-branch is taken (and indeed it can be killed, as %w is mapped to - // a new register and holds a fresh reference to the object referenced by %y). - // - // However, these limitations are unlikely to cause large leaks in practice. - - Expr VisitExpr_(const LetNode* let_node); - - private: - const ControlFlowGraph* cfg_; - const LivenessAnalysis* lva_; -}; - -/*! - * \brief Helper class to eliminate variable aliasing. This pass anticipates the VM compiler's - * register aliasing behavior so as to avoid killing vars that point to the same register. An - * alternative approach would be to track aliasing within the VM compiler itself, so that kill - * instructions are only emitted when all aliases are killed. - */ -class AliasEliminator : public MixedModeMutator { - public: - using MixedModeMutator::VisitExpr_; - - Expr VisitExpr_(const LetNode* let_node) override; - Expr VisitExpr_(const VarNode* var_node) override; - Expr VisitExpr_(const FunctionNode* func_node) override; - - // The only register-level aliasing that occurs in Match expressions is when - // the deconstructed expression is a Var, and the matched pattern is also a Var. - Expr VisitExpr_(const MatchNode* match_node) override; - - private: - /*! - * \brief Mapping of var -> var it's an alias of. Note that transitive aliases - * (e.g. x = 0; y = x; z = y) are mapped to the non-aliased variable (in this example "x"). - */ - std::unordered_map alias_; -}; - -Pass ManifestLifetimes(); - } // namespace transform } // namespace relay } // namespace tvm -#endif // TVM_RELAY_BACKEND_MANIFEST_LIFETIMES_H_ +#endif // TVM_RELAY_BACKEND_LIVENESS_ANALYSIS_H_ diff --git a/src/relay/backend/vm/manifest_lifetimes.cc b/src/relay/backend/vm/manifest_lifetimes.cc new file mode 100644 index 000000000000..486e06320345 --- /dev/null +++ b/src/relay/backend/vm/manifest_lifetimes.cc @@ -0,0 +1,260 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/backend/vm/manifest_lifetimes.cc + * \brief Analysis and explicit manifestation of variable lifetimes. NOTE: the input IR should be in + * ANF and post-memory-lowering (explicit manifestation of allocations). + */ + +#include + +#include "../../../support/arena.h" +#include "../../op/memory/device_copy.h" +#include "../../transforms/device_aware_visitors.h" +#include "../../transforms/let_list.h" +#include "../liveness_analysis.h" + +namespace tvm { +namespace relay { +namespace transform { + +/*! + * \brief Helper class to insert kills using liveness information. + */ +class KillInserter : public ExprMutator { + public: + KillInserter(const ControlFlowGraph* cfg, const LivenessAnalysis* lva) : cfg_(cfg), lva_(lva) {} + + // Limitations + // ----------- + // (1) For simplicity, we only insert kills when visiting Let bindings, and always emit the kill + // as a single subsequent binding. This is slightly inaccurate; for example, if the condition of + // an If is dead after the test, we can immediately kill the condition in each branch: + // let %x = if (%dead_cond) { + // let %_0 = memory.kill(%dead_cond); + // ... + // } else { + // let %_1 = memory.kill(%dead_cond); + // ... + // } + // as opposed to: + // let %x = if (%dead_cond) ... + // let %_0 = memory.kill(%dead_cond); + // + // (2) Killed variables are calculated as live in - live out, which misses variables that are + // actually dead but not in a live-in set. Example: + // @f(%x: int, %y: int, %c: bool) { + // let %w = if (%c) { + // let %z = %y + %y; + // %z + // } else { + // %y + // }; + // %w + // } + // After inserting kills: + // @f(%x: int, %y: int, %c: bool) { + // /* %x is always dead, so never in any live in or live out set */ + // let %w = if (%c) { + // let %z = %y + %y; + // let %_0 = memory.kill(%y); + // %z + // } else { + // %y + // /* %y is dead at this point */ + // }; + // let %_1 = memory.kill(%c); + // /* no kill for %y since it's not in the live-in of %w AND %w isn't a let binding */ + // %w + // } + // + // (3) When the result expr of an If branch is a variable, and this expr is the last use of the + // var, we cannot "kill" the var since it is being returned. The VM compiler also emits a Move + // instruction to merge the branch results, which creates another ObjectRef to the Object held + // by the var. The var is also not in the subsequent live-in (since it is indeed dead by this + // point), so it won't be killed. An example can be seen in the previous code block for (2), where + // %y is not killed if the else-branch is taken (and indeed it can be killed, as %w is mapped to + // a new register and holds a fresh reference to the object referenced by %y). + // + // However, these limitations are unlikely to cause large leaks in practice. + + Expr VisitExpr_(const LetNode* let_node) override { + Expr expr = GetRef(let_node); + LetList ll; + + while (const LetNode* inner_let_node = expr.as()) { + ll.Push(inner_let_node->var, VisitExpr(inner_let_node->value)); + + ICHECK(!inner_let_node->value.as()) << "aliasing should have been eliminated."; + ICHECK(cfg_->let_map.count(expr)) << "all Let exprs should be mapped in the CFG"; + + const ControlFlowGraph::NodePtr n = cfg_->let_map.at(expr); + + const VarSet& li = lva_->live_in.at(n); + const VarSet& lo = lva_->live_out.at(n); + + // Killed vars = live in - live out. + VarSet kills; + for (const Var& v : li) { + if (!lo.count(v)) { + kills.insert(v); + } + } + + for (const Var& v : kills) { + ll.Push(Call(Op::Get("memory.kill"), {v})); + } + + expr = inner_let_node->body; + } + + return ll.Get(VisitExpr(expr)); + } + + private: + const ControlFlowGraph* cfg_; + const LivenessAnalysis* lva_; +}; + +/*! + * \brief Helper class to eliminate variable aliasing. This pass anticipates the VM compiler's + * register aliasing behavior so as to avoid killing vars that point to the same register. An + * alternative approach would be to track aliasing within the VM compiler itself, so that kill + * instructions are only emitted when all aliases are killed. + */ +class AliasEliminator : public MixedModeMutator { + public: + using MixedModeMutator::VisitExpr_; + + Expr VisitExpr_(const LetNode* let_node) override { + Expr expr = GetRef(let_node); + LetList ll; + std::vector aliased_vars; + + while (const LetNode* inner_let_node = expr.as()) { + const Var& var = inner_let_node->var; + const Expr& val = inner_let_node->value; + bool aliased = false; + ICHECK(!alias_.count(var)); + + if (const VarNode* alias_of_n = AsIgnoringOnDevice(val)) { + alias_[var] = Downcast(VisitExpr_(alias_of_n)); + aliased = true; + } else if (AsIgnoringOnDevice(val)) { + // Copying to the same device is aliasing. + // WARNING: this must be kept in sync with the VM compiler logic in + // src/relay/backend/vm/compiler.cc, line 541, in DeviceAwareVisitExpr_(const CallNode*). + Expr unwrapped = IgnoreOnDevice(val); + DeviceCopyProps copy_props = GetDeviceCopyProps(unwrapped); + if (copy_props.body.defined()) { + if (copy_props.src_virtual_device->device_type() == + copy_props.dst_virtual_device->device_type() && + copy_props.src_virtual_device->virtual_device_id == + copy_props.dst_virtual_device->virtual_device_id) { + Expr to_copy = Downcast(unwrapped)->args[0]; + if (const VarNode* alias_of_n = to_copy.as()) { + alias_[var] = Downcast(VisitExpr_(alias_of_n)); + aliased = true; + } + } + } + } + + if (!aliased) { + ll.Push(var, VisitExpr(val)); + } else { + aliased_vars.push_back(var); + } + + expr = inner_let_node->body; + } + + Expr body = ll.Get(VisitExpr(expr)); + + // remove the aliased vars so that alias_ only tracks things in scope + for (const Var& v : aliased_vars) { + alias_.erase(v); + } + + return body; + } + + Expr VisitExpr_(const VarNode* var_node) override { + Var var = GetRef(var_node); + if (alias_.count(var)) { + return alias_[var]; + } + return var; + } + + Expr VisitExpr_(const FunctionNode* func_node) override { + Expr new_body = VisitExpr(func_node->body); + return WithFields(GetRef(func_node), /*opt_params=*/NullOpt, /*opt_body=*/new_body); + } + + // The only register-level aliasing that occurs in Match expressions is when + // the deconstructed expression is a Var, and the matched pattern is also a Var. + Expr VisitExpr_(const MatchNode* match_node) override { + if (const VarNode* data_var_node = AsIgnoringOnDevice(match_node->data)) { + Var data_var = Downcast(VisitExpr_(data_var_node)); + std::vector new_clauses; + for (const Clause& clause : match_node->clauses) { + const PatternVarNode* pv_node = nullptr; + if ((pv_node = clause->lhs.as())) { + alias_[pv_node->var] = data_var; + } + new_clauses.push_back(Clause(clause->lhs, VisitExpr(clause->rhs))); + if (pv_node) { + alias_.erase(pv_node->var); + } + } + return Match(data_var, new_clauses, match_node->complete, match_node->span); + } else { + return ExprMutator::VisitExpr_(match_node); + } + } + + private: + /*! + * \brief Mapping of var -> var it's an alias of. Note that transitive aliases + * (e.g. x = 0; y = x; z = y) are mapped to the non-aliased variable (in this example "x"). + */ + std::unordered_map alias_; +}; + +Pass ManifestLifetimes() { + auto pass_func = [](Function f, IRModule m, PassContext pc) -> Function { + f = Downcast(AliasEliminator().Mutate(f)); + Arena arena; + ControlFlowGraph cfg = ControlFlowGraph::Create(&arena, f); + UseDefAnalysis use_def = UseDefAnalysis::Analyze(cfg); + LivenessAnalysis lva = LivenessAnalysis::Analyze(cfg, use_def); + KillInserter ki(&cfg, &lva); + Function nf = Downcast(ki.Mutate(f)); + return nf; + }; + return CreateFunctionPass(pass_func, 0, "ManifestLifetimes", {}); +} + +TVM_REGISTER_GLOBAL("relay._transform.ManifestLifetimes").set_body_typed(ManifestLifetimes); + +} // namespace transform +} // namespace relay +} // namespace tvm diff --git a/tests/python/relay/aot/test_used_memory_annotator.py b/tests/python/relay/aot/test_used_memory_annotator.py index c882984a3835..698efab3c6cb 100644 --- a/tests/python/relay/aot/test_used_memory_annotator.py +++ b/tests/python/relay/aot/test_used_memory_annotator.py @@ -36,8 +36,9 @@ class CheckUsedMemoryAnnotation(ExprVisitor): what is expected. """ - def __init__(self, expected_annotations): + def __init__(self, expected_annotations, expected_io_annotation): self.expected_annotations = expected_annotations + self.expected_io_annotation = expected_io_annotation super().__init__() def visit_function(self, fn): @@ -56,14 +57,20 @@ def visit_function(self, fn): ) super().visit_function(fn) + def __call__(self, fn): + assert ( + fn.attrs["io_used_memory"] == self.expected_io_annotation + ), "Expected IO annotation did not match." + self.visit(fn.body) -def _check_used_memory_annotations(mod, expected_annotations): + +def _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation): mod = relay.transform.InferType()(mod) mod = relay.transform.ToANormalForm()(mod) mod = relay.transform.InferType()(mod) mod = AnnotateUsedMemory()(mod) - CheckUsedMemoryAnnotation(expected_annotations).visit(mod["main"].body) + CheckUsedMemoryAnnotation(expected_annotations, expected_io_annotation)(mod["main"]) def _create_primitive_function(expr): @@ -88,7 +95,8 @@ def get_inner_func(): mod = tvm.IRModule.from_expr(call) expected_annotations = [2 * (1 * 2 * 2 * 4)] - _check_used_memory_annotations(mod, expected_annotations) + expected_io_annotation = 2 * (1 * 2 * 2 * 4) + _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation) def test_multiple_functions(): @@ -116,7 +124,8 @@ def get_inner_func(ifm_shape): (1 * 7 * 7 * 2) + (1 * 6 * 6 * 2), (1 * 6 * 6 * 2) + (1 * 5 * 5 * 2), ] - _check_used_memory_annotations(mod, expected_annotations) + expected_io_annotation = (1 * 8 * 8 * 2) + (1 * 5 * 5 * 2) + _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation) def test_mixed_data_types(): @@ -138,7 +147,8 @@ def get_inner_func(): expected_annotations = [ (1 * 2 * 2 * 2) * 2 + (1 * 2 * 2 * 2) * 4, ] - _check_used_memory_annotations(mod, expected_annotations) + expected_io_annotation = (1 * 2 * 2 * 2) * 2 + (1 * 2 * 2 * 2) * 4 + _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation) def test_parallel_function_call(): @@ -166,7 +176,105 @@ def get_inner_func(): # the output tensor from the previous function is also alive (1 * 4 * 5 * 6) + (1 * 4 * 30) + (1 * 4 * 30), ] - _check_used_memory_annotations(mod, expected_annotations) + expected_io_annotation = (1 * 4 * 5 * 6) + (1 * 4 * 60) + _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation) + + +def test_many_different_parallel_calls(): + """ + Test a graph that calls many different functions in parallel. + + input + / | \ + prim_func_1 prim_func_2 prim_func_3 + \ | / + prim_func_4 + """ + + def get_inner_func_1(): + x = relay.var("x", shape=(1, 4, 5, 6), dtype="int8") + x = relay.tanh(x) + x = _create_primitive_function(x) + return x + + def get_inner_func_2(): + x = relay.var("x", shape=(1, 4, 5, 6), dtype="int8") + x = relay.nn.max_pool2d(x, pool_size=(1, 1), layout="NHWC") + x = _create_primitive_function(x) + return x + + def get_inner_func_3(): + x = relay.var("x", shape=(1, 4, 5, 6), dtype="int8") + x = relay.abs(x) + x = relay.nn.relu(x) + x = relay.exp(x) + x = _create_primitive_function(x) + return x + + def get_inner_func_4(): + x = relay.var("x", shape=(1, 4, 5, 6), dtype="int8") + y = relay.var("y", shape=(1, 4, 5, 6), dtype="int8") + z = relay.var("z", shape=(1, 4, 5, 6), dtype="int8") + out = relay.concatenate([x, y, z], axis=3) + out = _create_primitive_function(out) + return out + + ifm = relay.var("input", shape=(1, 4, 5, 6), dtype="int8") + x = relay.Call(get_inner_func_1(), [ifm]) + y = relay.Call(get_inner_func_2(), [ifm]) + z = relay.Call(get_inner_func_3(), [ifm]) + a = relay.Call(get_inner_func_4(), [x, y, z]) + mod = tvm.IRModule.from_expr(a) + + expected_annotations = [ + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 6), + # output from prim_func_1 is also still alive + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 6), + # outputs from prim_func_1 and prim_func_2 are also still alive + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 6), + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 18), + ] + expected_io_annotation = (1 * 4 * 5 * 6) + (1 * 4 * 5 * 18) + _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation) + + +def test_nested_branches(): + """ + Tests a graph with branches that also branch. + + input + / \ + / \ + prim_func_1 prim_func_2 + / \ + / \ + prim_func_3 prim_func_4 + """ + + def get_generic_inner_func(): + x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8") + x = relay.nn.relu(x) + return _create_primitive_function(x) + + ifm = relay.var("input", shape=(1, 2, 2, 4), dtype="int8") + a = relay.Call(get_generic_inner_func(), [ifm]) + b = relay.Call(get_generic_inner_func(), [ifm]) + c = relay.Call(get_generic_inner_func(), [b]) + d = relay.Call(get_generic_inner_func(), [b]) + out = relay.concatenate([a, c, d], axis=3) + mod = tvm.IRModule.from_expr(out) + + expected_annotations = [ + (1 * 2 * 2 * 4) + (1 * 2 * 2 * 4), + # output from prim_func_1 is also still alive + (1 * 2 * 2 * 4) + (1 * 2 * 2 * 4) + (1 * 2 * 2 * 4), + # output from prim_func_1 is also still alive + (1 * 2 * 2 * 4) + (1 * 2 * 2 * 4) + (1 * 2 * 2 * 4), + # outputs from prim_func_1 and prim_func_3 are also still alive + (1 * 2 * 2 * 4) + (1 * 2 * 2 * 4) + (1 * 2 * 2 * 4) + (1 * 2 * 2 * 4), + ] + expected_io_annotation = (1 * 2 * 2 * 4) + (1 * 2 * 2 * 12) + _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation) def test_composite_inner_function(): @@ -177,18 +285,18 @@ def test_composite_inner_function(): def get_inner_func(): x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8") - x = relay.nn.max_pool2d(x, pool_size=(2, 2)) + x = relay.nn.max_pool2d(x, pool_size=(2, 2), layout="NHWC") x = relay.Function(relay.analysis.free_vars(x), x) x = x.with_attr("Composite", "my_composite_func") y = relay.var("y", shape=(1, 2, 2, 4), dtype="int8") z = relay.Call(x, [y]) - z = _create_primitive_function(z) - return x + return _create_primitive_function(z) ifm = relay.var("input", shape=(1, 2, 2, 4), dtype="int8") x = relay.Call(get_inner_func(), [ifm]) mod = tvm.IRModule.from_expr(x) expected_annotations = [(1 * 2 * 2 * 4) + (1 * 1 * 1 * 4)] - _check_used_memory_annotations(mod, expected_annotations) + expected_io_annotation = (1 * 2 * 2 * 4) + (1 * 1 * 1 * 4) + _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation) From 93c06724113d34ede5aab8607e77df10833e04e9 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Mon, 30 May 2022 20:24:57 +0000 Subject: [PATCH 4/5] addressing comments Change-Id: I00f5ba80d5e004076e4c27d39bec143178b3b1dd --- include/tvm/relay/transform.h | 11 +- src/relay/backend/annotate_used_memory.cc | 34 ++-- .../{aot => }/test_used_memory_annotator.py | 166 ++++++++++++++++-- 3 files changed, 176 insertions(+), 35 deletions(-) rename tests/python/relay/{aot => }/test_used_memory_annotator.py (65%) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index d1bdff8f7a31..7faee3100844 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -557,11 +557,12 @@ TVM_DLL Pass PlanDevices(CompilationConfig config); TVM_DLL Pass FlattenAtrousConv(); /*! - * \brief Annotates the memory usage of each primitive function by analyzing the liveness - * of the input/output tensors at each function callsite and calculating the total amount of - * memory these tensors require. This is added as a "used_memory" annotation to the function - * in question. In addition, the containing function is annotated with an "io_used_memory" - * annotation which refers to the total memory required for the IO tensors. + * \brief Annotates the minimum required memory of each primitive function callsite by analyzing + * the liveness of the input/output tensors at each function callsite and calculating the total + * amount of memory these tensors require. This is added as a "used_memory" annotation to the + * function in question as a list of the number of bytes for each callsite. In addition, the + * containing function is annotated with an "io_used_memory" annotation which refers to the total + * memory required for the IO tensors. */ TVM_DLL Pass AnnotateUsedMemory(); diff --git a/src/relay/backend/annotate_used_memory.cc b/src/relay/backend/annotate_used_memory.cc index 5cc0a9a7d0ff..34acf37e88e8 100644 --- a/src/relay/backend/annotate_used_memory.cc +++ b/src/relay/backend/annotate_used_memory.cc @@ -30,6 +30,7 @@ #include #include "../transforms/device_aware_visitors.h" +#include "../transforms/pass_utils.h" #include "./liveness_analysis.h" #include "./utils.h" @@ -38,11 +39,12 @@ namespace relay { namespace backend { /*! - * \brief Annotates the memory usage of each primitive function by analyzing the liveness - * of the input/output tensors at each function callsite and calculating the total amount of - * memory these tensors require. This is added as a "used_memory" annotation to the function - * in question. In addition, the containing function is annotated with an "io_used_memory" - * annotation which refers to the total memory required for the IO tensors. + * \brief Annotates the minimum required memory of each primitive function callsite by analyzing + * the liveness of the input/output tensors at each function callsite and calculating the total + * amount of memory these tensors require. This is added as a "used_memory" annotation to the + * function in question as a list of the number of bytes for each callsite. In addition, the + * containing function is annotated with an "io_used_memory" annotation which refers to the total + * memory required for the IO tensors. * * A simple example: * @@ -57,8 +59,9 @@ namespace backend { * * After: * def @main(%input: Tensor[(1, 2, 2, 4), int8], io_used_memory=32) -> Tensor[(1, 2, 2, 4), int8] { - * let %x_0: fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1, used_memory=32) -> Tensor[(1, 2, 2, - * 4), int8] { nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0]) + * let %x_0: fn (%x: Tensor[(1, 2, 2, 4), int8], Primitive=1, used_memory=[32]) -> Tensor[(1, 2, + * 2, 4), int8] { + * nn.max_pool2d(%x, pool_size=[1, 1], padding=[0, 0, 0, 0]) * }; * let %x_1: Tensor[(1, 2, 2, 4), int8] = %x_0(%input); * %x_1 @@ -85,12 +88,14 @@ class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator { for (const Var& param : func->params) { Type type = param->checked_type(); ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory."; + ICHECK(!IsDynamic(type)) << "AnnotateUsedMemory does not support dynamic shapes."; io_used_memory += CalculateRelayExprSizeBytes(type); } // Outputs Type type = func->body->checked_type(); ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory."; + ICHECK(!IsDynamic(type)) << "AnnotateUsedMemory does not support dynamic shapes."; io_used_memory += CalculateRelayExprSizeBytes(type); Expr new_func_body = VisitExpr(func->body); @@ -146,19 +151,22 @@ class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator { for (const auto& var : live_tensors) { Type type = var->checked_type(); ICHECK(type.defined()) << "InferType pass should be run before AnnotateUsedMemory."; + ICHECK(!IsDynamic(type)) << "AnnotateUsedMemory does not support dynamic shapes."; used_memory += CalculateRelayExprSizeBytes(type); } - used_memory_annotations_[call_op] = used_memory; + IntImm annotation(DataType::UInt(64), used_memory); + used_memory_annotations_[call_op].push_back(annotation); } } else if (let_value->IsInstance()) { Function func = Downcast(let_value); ICHECK(used_memory_annotations_.find(let_var) != used_memory_annotations_.end()) << "Could not find used_memory value for primitive function bound at " << let_var->name_hint(); - uint64_t used_memory = used_memory_annotations_[let_var]; + Array used_memory = used_memory_annotations_[let_var]; used_memory_annotations_.erase(let_var); + Function new_func = WithAttr(std::move(func), "used_memory", - tvm::IntImm(tvm::DataType::UInt(64), used_memory)); + Array(used_memory.rbegin(), used_memory.rend())); return Let(let_var, new_func, post_let_node->body, post_let_node->span); } @@ -185,9 +193,9 @@ class AnnotateUsedMemoryMutator : public transform::DeviceAwareExprMutator { transform::LivenessAnalysis liveness_; /*! \brief Var's that reference primitive functions. */ std::unordered_set let_bound_prim_func_; - /*! \brief Stores the calculated used_memory values so they can be annotated on the relevant - * function. */ - std::unordered_map used_memory_annotations_; + /*! \brief Stores the calculated uint64 used_memory values so they can be annotated on the + * relevant function. */ + std::unordered_map, ObjectPtrHash, ObjectPtrEqual> used_memory_annotations_; }; } // namespace backend diff --git a/tests/python/relay/aot/test_used_memory_annotator.py b/tests/python/relay/test_used_memory_annotator.py similarity index 65% rename from tests/python/relay/aot/test_used_memory_annotator.py rename to tests/python/relay/test_used_memory_annotator.py index 698efab3c6cb..e339152294b6 100644 --- a/tests/python/relay/aot/test_used_memory_annotator.py +++ b/tests/python/relay/test_used_memory_annotator.py @@ -21,6 +21,8 @@ Relay function. """ +import pytest + import tvm from tvm import relay from tvm.relay.expr_functor import ExprVisitor @@ -50,7 +52,7 @@ def visit_function(self, fn): assert len(self.expected_annotations) > 0, "Not all expected annotations were compared" expected_mem = self.expected_annotations.pop(0) - actual_mem = fn.attrs["used_memory"] + actual_mem = [int(x) for x in fn.attrs["used_memory"]] assert expected_mem == actual_mem, ( f"Expected used memory annotation {expected_mem} " f"did not match actual annotation {actual_mem}" @@ -94,7 +96,9 @@ def get_inner_func(): call = relay.Call(get_inner_func(), [ifm]) mod = tvm.IRModule.from_expr(call) - expected_annotations = [2 * (1 * 2 * 2 * 4)] + expected_annotations = [ + [2 * (1 * 2 * 2 * 4)], + ] expected_io_annotation = 2 * (1 * 2 * 2 * 4) _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation) @@ -120,9 +124,9 @@ def get_inner_func(ifm_shape): mod = tvm.IRModule.from_expr(z) expected_annotations = [ - (1 * 8 * 8 * 2) + (1 * 7 * 7 * 2), - (1 * 7 * 7 * 2) + (1 * 6 * 6 * 2), - (1 * 6 * 6 * 2) + (1 * 5 * 5 * 2), + [(1 * 8 * 8 * 2) + (1 * 7 * 7 * 2)], + [(1 * 7 * 7 * 2) + (1 * 6 * 6 * 2)], + [(1 * 6 * 6 * 2) + (1 * 5 * 5 * 2)], ] expected_io_annotation = (1 * 8 * 8 * 2) + (1 * 5 * 5 * 2) _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation) @@ -145,7 +149,7 @@ def get_inner_func(): mod = tvm.IRModule.from_expr(x) expected_annotations = [ - (1 * 2 * 2 * 2) * 2 + (1 * 2 * 2 * 2) * 4, + [(1 * 2 * 2 * 2) * 2 + (1 * 2 * 2 * 2) * 4], ] expected_io_annotation = (1 * 2 * 2 * 2) * 2 + (1 * 2 * 2 * 2) * 4 _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation) @@ -172,9 +176,9 @@ def get_inner_func(): mod = tvm.IRModule.from_expr(z) expected_annotations = [ - (1 * 4 * 5 * 6) + (1 * 4 * 30), + [(1 * 4 * 5 * 6) + (1 * 4 * 30)], # the output tensor from the previous function is also alive - (1 * 4 * 5 * 6) + (1 * 4 * 30) + (1 * 4 * 30), + [(1 * 4 * 5 * 6) + (1 * 4 * 30) + (1 * 4 * 30)], ] expected_io_annotation = (1 * 4 * 5 * 6) + (1 * 4 * 60) _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation) @@ -227,12 +231,12 @@ def get_inner_func_4(): mod = tvm.IRModule.from_expr(a) expected_annotations = [ - (1 * 4 * 5 * 6) + (1 * 4 * 5 * 6), + [(1 * 4 * 5 * 6) + (1 * 4 * 5 * 6)], # output from prim_func_1 is also still alive - (1 * 4 * 5 * 6) + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 6), + [(1 * 4 * 5 * 6) + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 6)], # outputs from prim_func_1 and prim_func_2 are also still alive - (1 * 4 * 5 * 6) + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 6), - (1 * 4 * 5 * 6) + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 18), + [(1 * 4 * 5 * 6) + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 6)], + [(1 * 4 * 5 * 6) + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 6) + (1 * 4 * 5 * 18)], ] expected_io_annotation = (1 * 4 * 5 * 6) + (1 * 4 * 5 * 18) _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation) @@ -265,13 +269,13 @@ def get_generic_inner_func(): mod = tvm.IRModule.from_expr(out) expected_annotations = [ - (1 * 2 * 2 * 4) + (1 * 2 * 2 * 4), + [(1 * 2 * 2 * 4) + (1 * 2 * 2 * 4)], # output from prim_func_1 is also still alive - (1 * 2 * 2 * 4) + (1 * 2 * 2 * 4) + (1 * 2 * 2 * 4), + [(1 * 2 * 2 * 4) + (1 * 2 * 2 * 4) + (1 * 2 * 2 * 4)], # output from prim_func_1 is also still alive - (1 * 2 * 2 * 4) + (1 * 2 * 2 * 4) + (1 * 2 * 2 * 4), + [(1 * 2 * 2 * 4) + (1 * 2 * 2 * 4) + (1 * 2 * 2 * 4)], # outputs from prim_func_1 and prim_func_3 are also still alive - (1 * 2 * 2 * 4) + (1 * 2 * 2 * 4) + (1 * 2 * 2 * 4) + (1 * 2 * 2 * 4), + [(1 * 2 * 2 * 4) + (1 * 2 * 2 * 4) + (1 * 2 * 2 * 4) + (1 * 2 * 2 * 4)], ] expected_io_annotation = (1 * 2 * 2 * 4) + (1 * 2 * 2 * 12) _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation) @@ -297,6 +301,134 @@ def get_inner_func(): x = relay.Call(get_inner_func(), [ifm]) mod = tvm.IRModule.from_expr(x) - expected_annotations = [(1 * 2 * 2 * 4) + (1 * 1 * 1 * 4)] + expected_annotations = [ + [(1 * 2 * 2 * 4) + (1 * 1 * 1 * 4)], + ] expected_io_annotation = (1 * 2 * 2 * 4) + (1 * 1 * 1 * 4) _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation) + + +def test_multiple_calls_to_same_function(): + """ + Tests the case when there are multiple calls to the same function. + """ + + def get_inner_func(): + x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8") + x = relay.nn.max_pool2d(x) + x = _create_primitive_function(x) + return x + + inner_func = get_inner_func() + ifm = relay.var("input", shape=(1, 2, 2, 4), dtype="int8") + call1 = relay.Call(inner_func, [ifm]) + call2 = relay.Call(inner_func, [call1]) + mod = tvm.IRModule.from_expr(call2) + + expected_annotations = [[2 * (1 * 2 * 2 * 4), 2 * (1 * 2 * 2 * 4)]] + expected_io_annotation = 2 * (1 * 2 * 2 * 4) + _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation) + + +def test_parallel_calls_to_same_function(): + """ + Test parallel calls to the same function. + """ + + def get_inner_func(): + x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8") + x = relay.nn.max_pool2d(x) + x = _create_primitive_function(x) + return x + + inner_func = get_inner_func() + ifm = relay.var("input", shape=(1, 2, 2, 4), dtype="int8") + call1 = relay.Call(inner_func, [ifm]) + call2 = relay.Call(inner_func, [ifm]) + concat = relay.concatenate([call1, call2], axis=0) + mod = tvm.IRModule.from_expr(concat) + + expected_annotations = [[2 * (1 * 2 * 2 * 4), 3 * (1 * 2 * 2 * 4)]] + expected_io_annotation = 3 * (1 * 2 * 2 * 4) + _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation) + + +def test_parallel_calls_with_non_ifm_input(): + """ + Test a graph that calls many different functions in parallel where + the input is not the input to the function. + + y = f(x) + / | \ + z0 = g0(y) ... zi = gi(y) + \ | / + concat + """ + + def get_inner_func_1(): + x = relay.var("x", shape=(1, 4, 5, 6), dtype="int8") + x = relay.tanh(x) + x = _create_primitive_function(x) + return x + + def get_inner_func_2(): + x = relay.var("x", shape=(1, 4, 5, 6), dtype="int8") + x = relay.nn.max_pool2d(x, pool_size=(2, 2)) + x = _create_primitive_function(x) + return x + + ifm = relay.var("input", shape=(1, 4, 5, 6), dtype="int8") + y = relay.Call(get_inner_func_1(), [ifm]) + g = get_inner_func_2() + + no_calls = 20 + z = [relay.Call(g, [y]) for _ in range(0, no_calls)] + out = relay.concatenate(z, axis=3) + mod = tvm.IRModule.from_expr(out) + + expected_annotations = [ + [(1 * 4 * 5 * 6) + (1 * 4 * 5 * 6)], + [(1 * 4 * 5 * 6) + (1 * 4 * 4 * 5) * i for i in range(1, no_calls + 1)], + ] + expected_io_annotation = (1 * 4 * 5 * 6) + (1 * 4 * 4 * (5 * no_calls)) + _check_used_memory_annotations(mod, expected_annotations, expected_io_annotation) + + +def test_dynamic_io_tensor_not_supported(): + """ + Test to check dynamic IO tensor error. + """ + + def get_inner_func(): + x = relay.var("x", shape=(1, 2, 2, 4), dtype="int8") + x = relay.nn.max_pool2d(x) + x = _create_primitive_function(x) + return x + + ifm = relay.var("input", shape=(1, 2, 2, relay.Any()), dtype="int8") + call = relay.Call(get_inner_func(), [ifm]) + mod = tvm.IRModule.from_expr(call) + + err_rgx = r"AnnotateUsedMemory does not support dynamic shapes" + with pytest.raises(tvm.TVMError, match=err_rgx): + _check_used_memory_annotations(mod, [], []) + + +def test_dynamic_callsite_tensor_not_supported(): + """ + Test to check dynamic callsite tensor error. + """ + + def get_inner_func(): + x = relay.var("x", shape=(relay.Any(), 2, 2, 4), dtype="int8") + x = relay.nn.max_pool2d(x) + x = _create_primitive_function(x) + return x + + ifm = relay.var("input", shape=(1, 2, 2, 4), dtype="int8") + call = relay.Call(get_inner_func(), [ifm]) + mod = tvm.IRModule.from_expr(call) + + err_rgx = r"AnnotateUsedMemory does not support dynamic shapes" + with pytest.raises(tvm.TVMError, match=err_rgx): + _check_used_memory_annotations(mod, [], []) From 89f752389a63c5cd7d9201363cb80e434bd8f556 Mon Sep 17 00:00:00 2001 From: Luke Hutton Date: Mon, 6 Jun 2022 09:25:21 +0000 Subject: [PATCH 5/5] add note for dynamic shapes Change-Id: If6409e2953addfc880bcc6d95083b78bdf5a23d0 --- include/tvm/relay/transform.h | 3 +++ src/relay/backend/annotate_used_memory.cc | 3 +++ 2 files changed, 6 insertions(+) diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index 7faee3100844..1fef02557e09 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -563,6 +563,9 @@ TVM_DLL Pass FlattenAtrousConv(); * function in question as a list of the number of bytes for each callsite. In addition, the * containing function is annotated with an "io_used_memory" annotation which refers to the total * memory required for the IO tensors. + * + * Note: This pass does not support dynamic shapes, it is the users responsibility to check this + * pass isn't applied where dynamic shapes may be input. */ TVM_DLL Pass AnnotateUsedMemory(); diff --git a/src/relay/backend/annotate_used_memory.cc b/src/relay/backend/annotate_used_memory.cc index 34acf37e88e8..ad370c73ad1e 100644 --- a/src/relay/backend/annotate_used_memory.cc +++ b/src/relay/backend/annotate_used_memory.cc @@ -46,6 +46,9 @@ namespace backend { * containing function is annotated with an "io_used_memory" annotation which refers to the total * memory required for the IO tensors. * + * Note: This pass does not support dynamic shapes, it is the users responsibility to check this + * pass isn't applied where dynamic shapes may be input. + * * A simple example: * * Before: