diff --git a/include/tvm/ir.h b/include/tvm/ir.h
index 95e01382dd98..989802326ae4 100644
--- a/include/tvm/ir.h
+++ b/include/tvm/ir.h
@@ -152,6 +152,12 @@ constexpr const char* coproc_scope = "coproc_scope";
 constexpr const char* coproc_uop_scope = "coproc_uop_scope";
 /*! \brief Mark the scope as volatile access for certain handle. */
 constexpr const char* volatile_scope = "volatile_scope";
+/*!
+ * \brief Mark the scope as generated by an extern primitive.
+ *  Such a scope can contain an arbitrary IR program, so we need to be careful
+ *  when making assumptions about the structure of the program.
+ */
+constexpr const char* extern_scope = "extern_scope";
 /*!
  * \brief Mark the scope as when computation start to happen
  *  This can hint some code generator to create a new function for compute.
diff --git a/src/op/extern_op.cc b/src/op/extern_op.cc
index 9b302f6e2504..e83f97b14652 100644
--- a/src/op/extern_op.cc
+++ b/src/op/extern_op.cc
@@ -130,7 +130,7 @@ Stmt ExternOpNode::BuildProvide(
     const Stage& stage,
     const std::unordered_map<IterVar, Range>& dom_map) const {
   CHECK_EQ(stage->op.operator->(), this);
-  Stmt ret = this->body;
+  Stmt ret = AttrStmt::make(make_zero(Int(32)), attr::extern_scope, 0, this->body);
   auto f_push_bind = [&ret](Buffer buffer, Tensor tensor) {
     Array<NodeRef> bind_spec;
     Array<Expr> tuple;
diff --git a/src/pass/storage_rewrite.cc b/src/pass/storage_rewrite.cc
index 9d47a64f8837..5e7abdda2112 100644
--- a/src/pass/storage_rewrite.cc
+++ b/src/pass/storage_rewrite.cc
@@ -41,27 +41,32 @@ class LinearAccessPatternFinder final : public IRVisitor {
   struct StmtEntry {
     // The statment
     const Node* stmt;
-    // Scope used for allocation.
-    StorageScope alloc_scope;
+    // The index in linear_seq_ pointing to the other end of the nested scope.
+    // This is only set to non-zero if stmt is a nested scope.
+    // If offset > 0, this is the begin entry and the end entry is at current_index + offset.
+    // If offset < 0, this is the end entry and the begin entry is at current_index + offset.
+    int64_t scope_pair_offset{0};
     // The buffer variables this statment touched.
     std::vector<const Variable*> touched;
   };
+  // The scope of each allocation
+  struct AllocEntry {
+    // Scope used for allocation.
+    StorageScope storage_scope;
+    // scope level
+    size_t level{0};
+    // allocation stmt
+    const Allocate* alloc{nullptr};
+  };

-  // Get linear access pattern.
-  std::vector<StmtEntry> GetLinearSeq(const Stmt& s) {
-    this->Visit(s);
-    return std::move(linear_seq_);
-  }
   void Visit_(const Allocate* op) final {
     size_t level = scope_.size();
     const Variable* buf = op->buffer_var.get();
-    CHECK(!alloc_scope_level_.count(buf));
-    alloc_scope_level_[buf] = level;
-    StmtEntry e;
-    e.stmt = op;
-    e.alloc_scope = GetScope(buf);
-    e.touched.push_back(buf);
-    linear_seq_.emplace_back(std::move(e));
+    auto it = alloc_info_.find(buf);
+    CHECK(it != alloc_info_.end());
+    CHECK(it->second.alloc == nullptr);
+    it->second.alloc = op;
+    it->second.level = level;
     IRVisitor::Visit_(op);
   }
   void Visit_(const Store* op) final {
@@ -70,9 +75,10 @@ class LinearAccessPatternFinder final : public IRVisitor {
     IRVisitor::Visit_(op);
     // Add write access.
     const Variable* buf = op->buffer_var.get();
-    auto it = alloc_scope_level_.find(buf);
-    if (it != alloc_scope_level_.end()) {
-      scope_[it->second].touched.push_back(buf);
+    auto it = alloc_info_.find(buf);
+    if (it != alloc_info_.end() && it->second.alloc) {
+      CHECK_LT(it->second.level, scope_.size());
+      scope_[it->second.level].touched.push_back(buf);
     }
     StmtEntry e = scope_.back();
     scope_.pop_back();
@@ -96,11 +102,11 @@ class LinearAccessPatternFinder final : public IRVisitor {
     // Add write access.
     IRVisitor::Visit_(op);
     const Variable* buf = op->buffer_var.get();
-    auto it = alloc_scope_level_.find(buf);
-    if (it != alloc_scope_level_.end()) {
-      CHECK_LT(it->second, scope_.size())
+    auto it = alloc_info_.find(buf);
+    if (it != alloc_info_.end() && it->second.alloc) {
+      CHECK_LT(it->second.level, scope_.size())
           << "Load memory in places other than store.";
-      scope_[it->second].touched.push_back(buf);
+      scope_[it->second.level].touched.push_back(buf);
     }
   }
   void Visit_(const Call* op) final {
@@ -113,10 +119,11 @@
   }
   void Visit_(const Variable* buf) final {
     // Directly reference to the variable count as a read.
-    auto it = alloc_scope_level_.find(buf);
-    if (it != alloc_scope_level_.end()) {
-      CHECK_LT(it->second, scope_.size()) << " buf=" << buf->name_hint;
-      scope_[it->second].touched.push_back(buf);
+    auto it = alloc_info_.find(buf);
+    if (it != alloc_info_.end() && it->second.alloc) {
+      CHECK_LT(it->second.level, scope_.size())
+          << " buf=" << buf->name_hint;
+      scope_[it->second.level].touched.push_back(buf);
     }
   }
   template<typename T>
@@ -124,13 +131,20 @@
     scope_.push_back(StmtEntry());
     StmtEntry e;
     e.stmt = op;
+    int64_t begin_index = static_cast<int64_t>(linear_seq_.size());
     // before scope.
     linear_seq_.push_back(e);
     IRVisitor::Visit_(op);
     // after scope.
     e.touched = std::move(scope_.back().touched);
     scope_.pop_back();
+    int64_t end_index = static_cast<int64_t>(linear_seq_.size());
+    CHECK_GT(end_index, begin_index);
+    e.scope_pair_offset = begin_index - end_index;
     linear_seq_.push_back(e);
+    // record the pointer to end index.
+    CHECK_NE(end_index, 0U);
+    linear_seq_[begin_index].scope_pair_offset = end_index - begin_index;
   }
   void Visit_(const AttrStmt* op) final {
     // Only record the outer most thread extent.
@@ -138,9 +152,11 @@
       in_thread_env_ = true;
       VisitNewScope(op);
       in_thread_env_ = false;
+    } else if (op->attr_key == attr::extern_scope) {
+      VisitNewScope(op);
     } else if (op->attr_key == attr::storage_scope) {
       const Variable* buf = op->node.as<Variable>();
-      storage_scope_[buf] =
+      alloc_info_[buf].storage_scope =
           StorageScope::make(op->value.as<StringImm>()->value);
       IRVisitor::Visit_(op);
     } else {
@@ -155,36 +171,156 @@
     VisitNewScope(op);
   }

+  // linearized access sequence.
+  std::vector<StmtEntry> linear_seq_;
+  // The storage scope of each buffer
+  std::unordered_map<const Variable*, AllocEntry> alloc_info_;
+
  private:
-  // Get storage scope of buffer.
-  StorageScope GetScope(const Variable* buf) const {
-    auto it = storage_scope_.find(buf);
-    CHECK(it != storage_scope_.end());
-    return it->second;
-  }
   // Whether already in thread env.
   bool in_thread_env_{false};
-  // linearized access sequence.
-  std::vector<StmtEntry> linear_seq_;
   // The scope stack.
   std::vector<StmtEntry> scope_;
-  // The storage scope of each buffer
-  std::unordered_map<const Variable*, StorageScope> storage_scope_;
-  // buffer -> allocated scope level in the IR.
-  std::unordered_map<const Variable*, size_t> alloc_scope_level_;
+};
+
+// Verify if the statement can be run safely in an inplace fashion.
+//
+// Detect pattern: dst[index] = f(src[index])
+//
+// WARNING: the current detection algorithm cannot handle the case
+// when a location in an array is written multiple times.
+//
+// For example, the following program will pass the check,
+// but we cannot make A and B the same array.
+//
+//  A[0] = B[0] + 1
+//  A[0] = B[0] + 1
+//
+// The high-level code generator needs to ensure that the generated
+// code writes each location of the target array only once.
+//
+// This is the case with IR generated by the current compute schedule.
+// We explicitly return false if we find there is an extern block
+// which can contain arbitrary IR.
+//
+// Nevertheless, the inplace detector should be used with care.
+// We may also consider introducing a condition checker that checks
+// whether every index is visited only once, as a sufficient condition.
+//
+// The code after inplace transformation is no longer idempotent.
+//
+class InplaceOpVerifier : public IRVisitor {
+ public:
+  bool Check(const Node* stmt,
+             const Variable* dst,
+             const Variable* src) {
+    dst_ = dst;
+    src_ = src;
+    result_ = true;
+    if (stmt->is_type<AttrStmt>()) {
+      Visit_(static_cast<const AttrStmt*>(stmt));
+    } else if (stmt->is_type<For>()) {
+      Visit_(static_cast<const For*>(stmt));
+    } else if (stmt->is_type<IfThenElse>()) {
+      Visit_(static_cast<const IfThenElse*>(stmt));
+    } else if (stmt->is_type<Store>()) {
+      Visit_(static_cast<const Store*>(stmt));
+    } else {
+      return false;
+    }
+    return result_;
+  }
+
+  using IRVisitor::Visit_;
+
+  void Visit(const NodeRef& e) final {
+    if (!result_) return;
+    IRVisitor::Visit(e);
+  }
+
+  void Visit_(const Variable* op) final {
+    // assume all opaque access is unsafe
+    if (op == dst_ || op == src_) {
+      result_ = false; return;
+    }
+  }
+
+  void Visit_(const Store* op) final {
+    ++mem_nest_;
+    this->Visit(op->index);
+    --mem_nest_;
+    if (op->buffer_var.get() == dst_) {
+      store_ = op;
+      this->Visit(op->value);
+      this->Visit(op->predicate);
+      store_ = nullptr;
+    } else {
+      this->Visit(op->value);
+      this->Visit(op->predicate);
+    }
+  }
+
+  void Visit_(const AttrStmt* op) final {
+    // always reject extern code
+    if (op->attr_key == attr::extern_scope ||
+        op->attr_key == attr::volatile_scope) {
+      result_ = false; return;
+    }
+    IRVisitor::Visit_(op);
+  }
+
+  void Visit_(const Load* op) final {
+    const Variable* buf = op->buffer_var.get();
+    // cannot read from dst_ (no reduction)
+    if (buf == dst_) {
+      result_ = false; return;
+    }
+    // do not allow indirect memory load
+    if (mem_nest_ != 0) {
+      result_ = false; return;
+    }
+    if (src_ == buf) {
+      if (store_ == nullptr ||
+          store_->value.type() != op->type ||
+          !ir::Equal(store_->index, op->index)) {
+        result_ = false; return;
+      }
+    }
+    ++mem_nest_;
+    IRVisitor::Visit_(op);
+    --mem_nest_;
+  }
+
+
+ private:
+  // result of the check
+  bool result_{true};
+  // destination memory
+  const Variable* dst_;
+  // source variable
+  const Variable* src_;
+  // counter of load,
+  // it is not safe to inplace when there is a nested load like A[B[i]]
+  int mem_nest_{0};
+  // The current store to be inspected
+  const Store* store_{nullptr};
 };

 // Planner to plan and rewrite memory allocation.
 class StoragePlanRewriter : public IRMutator {
  public:
   using StmtEntry = LinearAccessPatternFinder::StmtEntry;
+  using AllocEntry = LinearAccessPatternFinder::AllocEntry;

-  Stmt Rewrite(Stmt stmt) {
-    std::vector<StmtEntry> seq =
-        LinearAccessPatternFinder().GetLinearSeq(stmt);
-    this->FindFreeLocation(seq);
-    this->PlanMemory(seq);
+  Stmt Rewrite(Stmt stmt, bool detect_inplace) {
+    detect_inplace_ = detect_inplace;
+    // plan the rewrite
+    LinearAccessPatternFinder finder;
+    finder.Visit(stmt);
+    this->LivenessAnalysis(finder.linear_seq_);
+    this->PlanMemory(finder.linear_seq_, finder.alloc_info_);
     this->PrepareNewAlloc();
+    // start rewrite
     stmt = this->Mutate(stmt);
     if (attach_map_.count(nullptr)) {
       std::vector<Stmt> nest;
@@ -308,7 +444,6 @@
   }

  private:
-  // Alllocate entry of node.
   struct StorageEntry {
     // The scope that this alloc attaches after
     // For shared/local memory it is beginning of the thread extent.
@@ -332,6 +467,16 @@
     // the address becomes alloc_var + sizeof(elem_type) * elem_offset;
     uint64_t elem_offset{0};
   };
+
+  // Allocate entry of node.
+  // Event entry in liveness analysis
+  struct EventEntry {
+    // variables we generate
+    std::vector<const Variable*> gen;
+    // variables we kill
+    std::vector<const Variable*> kill;
+  };
+
   Stmt MakeAttach(const std::vector<StorageEntry*>& svec,
                   Stmt body) {
     std::vector<Stmt> nest;
@@ -461,16 +606,29 @@
             << "Allocation exceed bound of memory tag " << e->scope.to_string();
     }
   }
-  // Find the free location of each varaible.
-  // Just do a reverse linear scan.
-  void FindFreeLocation(const std::vector<StmtEntry>& seq) {
+  // Liveness analysis to find the gen and kill point of each variable.
+  void LivenessAnalysis(const std::vector<StmtEntry>& seq) {
+    // find kill point, do a reverse linear scan.
     std::unordered_set<const Variable*> touched;
     for (size_t i = seq.size(); i != 0; --i) {
       const StmtEntry& s = seq[i - 1];
       for (const Variable* buffer : s.touched) {
         if (!touched.count(buffer)) {
           touched.insert(buffer);
-          free_loc_[i - 1].push_back(buffer);
+          event_map_[s.stmt].kill.push_back(buffer);
+        }
+      }
+    }
+    // find gen point, do a forward scan
+    touched.clear();
+    for (size_t i = 0; i < seq.size(); ++i) {
+      int64_t offset = seq[i].scope_pair_offset;
+      if (offset < 0) continue;
+      const StmtEntry& s = seq[i + offset];
+      for (const Variable* buffer : s.touched) {
+        if (!touched.count(buffer)) {
+          touched.insert(buffer);
+          event_map_[s.stmt].gen.push_back(buffer);
         }
       }
     }
@@ -500,14 +658,66 @@
   }
   // Memory plan algorithm
-  void PlanMemory(const std::vector<StmtEntry>& seq) {
+  void PlanMemory(const std::vector<StmtEntry>& seq,
+                  const std::unordered_map<const Variable*, AllocEntry>& alloc_info) {
+    std::unordered_set<const Variable*> inplace_flag;
+
     for (size_t i = 0; i < seq.size(); ++i) {
       const StmtEntry& s = seq[i];
+      auto it = event_map_.find(seq[i].stmt);
+
+      // scope_pair_offset >= 0 means it is either
+      // - leaf stmt (offset = 0)
+      // - beginning of scope (offset > 0)
+      // In both cases, we need to handle the gen event correctly
+      if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) {
+        // Inplace operation detection
+        // specially handle this
+        bool detect_inplace = detect_inplace_ && (it->second.gen.size() <= 2);
+
+        for (const Variable* var : it->second.gen) {
+          CHECK(alloc_info.count(var));
+          const AllocEntry& ae = alloc_info.at(var);
+          StorageEntry* dst_entry = nullptr;
+          // inplace detection
+          if (detect_inplace) {
+            for (const Variable* src : it->second.kill) {
+              if (!inplace_flag.count(src) && alloc_map_.count(src)) {
+                InplaceOpVerifier visitor;
+                StorageEntry* src_entry = alloc_map_.at(src);
+                if (src_entry->scope == ae.storage_scope &&
+                    src_entry->attach_scope_ == thread_scope_ &&
+                    src_entry->elem_type == ae.alloc->type.element_of() &&
+                    visitor.Check(s.stmt, var, src)) {
+                  uint64_t const_nbits = static_cast<uint64_t>(
+                      ae.alloc->constant_allocation_size() *
+                      ae.alloc->type.bits() *
+                      ae.alloc->type.lanes());
+                  if (src_entry->const_nbits == const_nbits) {
+                    // successfully inplace
+                    dst_entry = src_entry;
+                    inplace_flag.insert(src);
+                  }
+                }
+              }
+            }
+          }
+          if (dst_entry == nullptr) {
+            dst_entry = FindAlloc(ae.alloc, thread_scope_, ae.storage_scope);
+          }
+          dst_entry->allocs.emplace_back(ae.alloc);
+          alloc_map_[var] = dst_entry;
+        }
+      }
+      // enter/exit new scope
       if (s.stmt->is_type<AttrStmt>()) {
         const auto* op = static_cast<const AttrStmt*>(s.stmt);
-        CHECK(op->attr_key == attr::thread_extent ||
-              op->attr_key == attr::pragma_scope);
-        PlanNewScope(op);
+        if (op->attr_key == attr::thread_extent ||
+            op->attr_key == attr::pragma_scope) {
+          PlanNewScope(op);
+        } else {
+          CHECK(op->attr_key == attr::extern_scope);
+        }
       } else if (s.stmt->is_type<For>()) {
         const auto* op = static_cast<const For*>(s.stmt);
         if (op->for_type == ForType::Parallel) {
@@ -515,16 +725,17 @@
             PlanNewScope(op);
           }
         }
-      } else if (s.stmt->is_type<Allocate>()) {
-        const auto* op = static_cast<const Allocate*>(s.stmt);
-        StorageEntry* e = this->FindAlloc(op, thread_scope_, s.alloc_scope);
-        e->allocs.emplace_back(op);
-        alloc_map_[op->buffer_var.get()] = e;
       }
-      // free list
-      if (free_loc_.count(i)) {
-        for (const Variable* var : free_loc_.at(i)) {
-          this->Free(var);
+      // scope_pair_offset <= 0 means it is either
+      // - leaf stmt (offset = 0)
+      // - end of scope (offset < 0)
+      // In both cases, we need to handle the kill event correctly
+      if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
+        for (const Variable* var : it->second.kill) {
+          // skip spaces that are already replaced by inplace
+          if (!inplace_flag.count(var)) {
+            this->Free(var);
+          }
         }
       }
     }
@@ -534,6 +745,7 @@
                          const Node* attach_scope,
                          const StorageScope& scope,
                          size_t const_nbits) {
+    CHECK(op != nullptr);
     // Re-use not successful, allocate a new buffer.
     std::unique_ptr<StorageEntry> entry(new StorageEntry());
     entry->attach_scope_ = attach_scope;
@@ -544,9 +756,11 @@
     alloc_vec_.emplace_back(std::move(entry));
     return e;
   }
+
   StorageEntry* FindAlloc(const Allocate* op,
                           const Node* attach_scope,
                           const StorageScope& scope) {
+    CHECK(op != nullptr);
     // skip plan for local variable,
     // compiler can do a better job with register allocation.
     const uint64_t match_range = 16;
@@ -603,6 +817,7 @@
     auto it = alloc_map_.find(var);
     CHECK(it != alloc_map_.end());
     StorageEntry* e = it->second;
+    CHECK_NE(e->allocs.size(), 0U);
     // Disable sharing of local memory.
     if (e->scope.rank > 1 || e->allocs[0]->type.is_handle()) return;
     // disable reuse of small arrays
@@ -616,17 +831,18 @@
   }
   // thread scope.
   const Node* thread_scope_{nullptr};
+  // whether to enable inplace detection.
+  bool detect_inplace_{false};
   // Locations of free ops.
-  std::unordered_map<size_t, std::vector<const Variable*> > free_loc_;
-  // The allocation attach map
-  std::unordered_map<const Node*, std::vector<StorageEntry*> > attach_map_;
-  // The allocation assign map
-  std::unordered_map<const Variable*, StorageEntry*> alloc_map_;
+  std::unordered_map<const Node*, EventEntry> event_map_;
   // constant size free map.
   std::multimap<uint64_t, StorageEntry*> const_free_map_;
   // symbolic free list, for non constant items.
   std::list<StorageEntry*> sym_free_list_;
+  // The allocation attach map
+  std::unordered_map<const Node*, std::vector<StorageEntry*> > attach_map_;
+  // The allocation assign map
+  std::unordered_map<const Variable*, StorageEntry*> alloc_map_;
   // The allocations
   std::vector<std::unique_ptr<StorageEntry> > alloc_vec_;
 };
@@ -693,7 +909,7 @@ class VectorAllocRewriter : public IRMutator {

 Stmt StorageRewrite(Stmt stmt) {
-  stmt = StoragePlanRewriter().Rewrite(stmt);
+  stmt = StoragePlanRewriter().Rewrite(stmt, true);
   return VectorAllocRewriter().Mutate(stmt);
 }
 }  // namespace ir
diff --git a/tests/python/unittest/test_codegen_extern.py b/tests/python/unittest/test_codegen_extern.py
index 43736bc46768..1295ed26cce7 100644
--- a/tests/python/unittest/test_codegen_extern.py
+++ b/tests/python/unittest/test_codegen_extern.py
@@ -15,6 +15,7 @@ def extern_generator(ins, outs):

     C = tvm.extern(A.shape, [A], extern_generator, name='C')
     s = tvm.create_schedule(C.op)
+    print(tvm.lower(s, [A, C], simple_mode=True))

     def check_llvm():
         if not tvm.module.enabled("llvm"):
diff --git a/tests/python/unittest/test_pass_storage_rewrite.py b/tests/python/unittest/test_pass_storage_rewrite.py
index 4d2110319d2c..d3f6307f821f 100644
--- a/tests/python/unittest/test_pass_storage_rewrite.py
+++ b/tests/python/unittest/test_pass_storage_rewrite.py
@@ -19,14 +19,39 @@ def test_storage_share():
     stmt = tvm.ir_pass.CanonicalSimplify(stmt)
     stmt = tvm.ir_pass.Simplify(stmt)
     stmt = tvm.ir_pass.StorageRewrite(stmt)
-    # verify only have two allocations.
-    # verify that the data is folded.
+    # verify that only one allocation remains.
+    # verify inplace folding works
+    num_alloc = [0]
+    def verify(n):
+        if isinstance(n, tvm.stmt.Allocate):
+            num_alloc[0] += 1
+    tvm.ir_pass.PostOrderVisit(stmt, verify)
+    assert num_alloc[0] == 1
+
+
+def test_inplace_rule():
+    m = 10
+    A = tvm.placeholder((m,), name='A')
+    A0 = tvm.compute((m,), lambda i: A[i], name='A0')
+    A1 = tvm.compute((m,), lambda i: A[i] + 1, name='A1')
+    AA = tvm.compute((m,), lambda i: A0[i] + A1[i] + A1[0], name='AA')
+    B = tvm.compute((m,), lambda i: AA[i] + 1, name='B')
+    s = tvm.create_schedule(B.op)
+    bounds = tvm.schedule.InferBound(s)
+    assert isinstance(bounds, tvm.container.Map)
+    stmt = tvm.schedule.ScheduleOps(s, bounds)
+    Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
+    Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
+    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
+    stmt = tvm.ir_pass.CanonicalSimplify(stmt)
+    stmt = tvm.ir_pass.Simplify(stmt)
+    stmt = tvm.ir_pass.StorageRewrite(stmt)
+    # verify that only two allocations remain.
+    # verify inplace folding works
     num_alloc = [0]
     def verify(n):
         if isinstance(n, tvm.stmt.Allocate):
             num_alloc[0] += 1
-        elif isinstance(n, tvm.stmt.Store):
-            assert n.buffer_var != n.value.a.buffer_var
     tvm.ir_pass.PostOrderVisit(stmt, verify)
     assert num_alloc[0] == 2
@@ -38,7 +63,7 @@ def test_storage_combine():
     B = A
     stages = []
     for t in range(num_stage):
-        B = tvm.compute((n, ), lambda i: B[i] + (t+1), name='A%d' % t)
+        B = tvm.compute((n, ), lambda i: B[i] + B[0] + (t+1), name='A%d' % t)
         stages.append(B)

     s = tvm.create_schedule(B.op)
@@ -121,12 +146,14 @@ def test_parallel_alloc():
             A[j] = A[j] + 2
     body = ib.get()
     body = tvm.ir_pass.StorageRewrite(body)
+    assert(isinstance(body.body.body.body.body, tvm.stmt.Allocate))


 if __name__ == "__main__":
+    test_inplace_rule()
+    test_storage_share()
     test_parallel_alloc()
     test_storage_combine()
     test_storage_share_gpu()
-    test_storage_share()
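A minimal sketch of how the new inplace folding can be observed from Python, using only the 0.x-era tvm.ir_pass API already exercised by the tests above. The three-stage pipeline and the stage names (S0, S1, S2) are illustrative and not part of the patch; it simply mirrors test_storage_share, printing the allocation count instead of asserting it.

    import tvm

    # Chain of elementwise stages: each intermediate follows the
    # dst[i] = f(src[i]) pattern that InplaceOpVerifier accepts, so the
    # rewritten pass should let the internal stages share one buffer.
    m = 16
    A = tvm.placeholder((m,), name='A')
    B = A
    for t in range(3):
        B = tvm.compute((m,), lambda i: B[i] + 1, name='S%d' % t)

    s = tvm.create_schedule(B.op)
    bounds = tvm.schedule.InferBound(s)
    stmt = tvm.schedule.ScheduleOps(s, bounds)
    Ab = tvm.decl_buffer(A.shape, A.dtype, name='A')
    Bb = tvm.decl_buffer(B.shape, B.dtype, name='B')
    stmt = tvm.ir_pass.StorageFlatten(stmt, {A: Ab, B: Bb}, 64)
    stmt = tvm.ir_pass.CanonicalSimplify(stmt)
    stmt = tvm.ir_pass.Simplify(stmt)
    stmt = tvm.ir_pass.StorageRewrite(stmt)

    # Count surviving Allocate nodes; with inplace detection the two internal
    # stages are expected to collapse into a single allocation.
    num_alloc = [0]
    def count_alloc(n):
        if isinstance(n, tvm.stmt.Allocate):
            num_alloc[0] += 1
    tvm.ir_pass.PostOrderVisit(stmt, count_alloc)
    print("allocations after StorageRewrite:", num_alloc[0])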