From a84db6958a7b8d754a748af8fad6ff27196e1cc4 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Thu, 1 Jul 2021 06:24:41 +0900
Subject: [PATCH 01/90] Add storage scope to ProducerRealize, always create a
 buffer with scope

---
 include/tvm/te/operation.h                | 15 ++++++------
 include/tvm/tir/buffer.h                  | 10 +++++++-
 include/tvm/tir/stmt.h                    |  9 ++++++--
 python/tvm/tir/ir_builder.py              |  2 +-
 python/tvm/tir/stmt.py                    |  7 ++++--
 src/runtime/thread_storage_scope.h        |  4 +++-
 src/te/operation/compute_op.cc            |  4 ++--
 src/te/operation/extern_op.cc             |  4 ++--
 src/te/operation/hybrid_op.cc             |  4 ++--
 src/te/operation/placeholder_op.cc        |  2 +-
 src/te/operation/scan_op.cc               |  4 ++--
 src/tir/ir/buffer.cc                      | 12 ++++++++--
 src/tir/ir/stmt.cc                        |  7 +++---
 src/tir/transforms/thread_storage_sync.cc | 28 ++++++++---------------
 14 files changed, 66 insertions(+), 46 deletions(-)

diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h
index 27e48999a7d1..13f39317dbe4 100644
--- a/include/tvm/te/operation.h
+++ b/include/tvm/te/operation.h
@@ -125,11 +125,12 @@ class TVM_DLL OperationNode : public Object {
    * \param stage the op's stage.
    * \param realize_map The realization domain map of the operators.
    * \param body The body that is going to get
+   * \param storage_scope The storage scope associated with this realization
    * \return A realization statement that wraps body.
    */
   virtual Stmt BuildRealize(const Stage& stage,
-                            const std::unordered_map<IterVar, Range>& realize_map,
-                            const Stmt& body) const = 0;
+                            const std::unordered_map<IterVar, Range>& realize_map, const Stmt& body,
+                            String storage_scope = "") const = 0;
   /*!
    * \brief Build the statement that provide the output tensors.
    * \param stage The schedule stage of the op.
@@ -168,7 +169,7 @@ class PlaceholderOpNode : public OperationNode {
   void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                    std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
-                    const Stmt& body) const final;
+                    const Stmt& body, String storage_scope = "") const final;
   Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                     bool debug_keep_trivial_loop) const final;
@@ -212,7 +213,7 @@ class TVM_DLL BaseComputeOpNode : public OperationNode {
   void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                    std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
-                    const Stmt& body) const final;
+                    const Stmt& body, String storage_scope = "") const final;
   virtual size_t num_schedulable_dims() const = 0;

   static constexpr const char* _type_key = "BaseComputeOp";
@@ -370,7 +371,7 @@ class ScanOpNode : public OperationNode {
   void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                    std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
-                    const Stmt& body) const final;
+                    const Stmt& body, String storage_scope = "") const final;
   Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                     bool debug_keep_trivial_loop) const final;
@@ -433,7 +434,7 @@ class ExternOpNode : public OperationNode {
   void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                    std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
-                    const Stmt& body) const final;
+                    const Stmt& body, String storage_scope = "") const final;
   Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                     bool debug_keep_trivial_loop) const final;
@@ -498,7 +499,7 @@ class HybridOpNode : public OperationNode {
   void GatherBound(const Operation& self, const std::unordered_map<Tensor, TensorDom>& tensor_dom,
                    std::unordered_map<IterVar, Range>* out_dom_map) const final;
   Stmt BuildRealize(const Stage& stage, const std::unordered_map<IterVar, Range>& realize_map,
-                    const Stmt& body) const final;
+                    const Stmt& body, String storage_scope = "") const final;
   Stmt BuildProvide(const Stage& stage, const std::unordered_map<IterVar, Range>& dom_map,
                     bool debug_keep_trivial_loop) const final;
diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h
index 017f4f7052b1..ed5718d8f358 100644
--- a/include/tvm/tir/buffer.h
+++ b/include/tvm/tir/buffer.h
@@ -191,12 +191,20 @@ class Buffer : public ObjectRef {
 * \param shape The shape of the buffer,
 * \param dtype The content data type.
 * \param name The name of the buffer
+ * \param storage_scope The storage scope associated with this buffer
 * \param span The location of this object in the source code.
 * \return The created buffer.
 * \sa Buffer for complete constructor.
 */
 TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
-                           String name = "buffer", Span span = Span());
+                           String name = "buffer", String storage_scope = "", Span span = Span());
+
+/*!
+ * \brief Return the storage scope associated with a buffer variable.
+ * \param buffer_var The input buffer variable.
+ * \return A string representing the storage scope of this buffer variable.
+ */
+TVM_DLL String GetStorageScope(Var buffer_var);

 /*!
  * \brief Base node for data producers.
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index cc10c218c8ff..9997a4d95694 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -464,18 +464,22 @@ class ProducerRealizeNode : public StmtNode {
   PrimExpr condition;
   /*! \brief The body of realization. */
   Stmt body;
+  /*! \brief The storage scope associated with this realization. */
+  String storage_scope;

   void VisitAttrs(AttrVisitor* v) {
     v->Visit("producer", &producer);
     v->Visit("bounds", &bounds);
     v->Visit("condition", &condition);
     v->Visit("body", &body);
+    v->Visit("storage_scope", &storage_scope);
     v->Visit("span", &span);
   }

   bool SEqualReduce(const ProducerRealizeNode* other, SEqualReducer equal) const {
     return equal(producer, other->producer) && equal(bounds, other->bounds) &&
-           equal(condition, other->condition) && equal(body, other->body);
+           equal(condition, other->condition) && equal(body, other->body) &&
+           equal(storage_scope, other->storage_scope);
   }

   void SHashReduce(SHashReducer hash_reduce) const {
@@ -483,6 +487,7 @@ class ProducerRealizeNode : public StmtNode {
     hash_reduce(bounds);
     hash_reduce(condition);
     hash_reduce(body);
+    hash_reduce(storage_scope);
   }

   static constexpr const char* _type_key = "tir.ProducerRealize";
@@ -496,7 +501,7 @@ class ProducerRealizeNode : public StmtNode {
 class ProducerRealize : public Stmt {
  public:
   TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body,
-                          Span span = Span());
+                          String storage_scope = "", Span span = Span());

   TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode);
 };
diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py
index 4934bf04727f..5aae068f4d58 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder.py
@@ -416,7 +416,7 @@ def allocate(self, dtype, shape, name="buf", scope=None):
         buffer : BufferVar
             The buffer var representing the buffer.
""" - buffer_var = _expr.Var(name, PointerType(PrimType(dtype))) + buffer_var = _expr.Var(name, PointerType(PrimType(dtype), scope)) if not isinstance(shape, (list, tuple, _container.Array)): shape = [shape] if scope: diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index dd7665a56692..94074b906777 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -364,13 +364,16 @@ class ProducerRealize(Stmt): body : Stmt The realize body + storage_scope : str + The storage scope associated with this realization + span : Optional[Span] The location of this itervar in the source code. """ - def __init__(self, producer, bounds, condition, body, span=None): + def __init__(self, producer, bounds, condition, body, storage_scope="", span=None): self.__init_handle_by_constructor__( - _ffi_api.ProducerRealize, producer, bounds, condition, body, span + _ffi_api.ProducerRealize, producer, bounds, condition, body, storage_scope, span ) diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index c0393600b60c..d93a1f130bae 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -118,7 +118,9 @@ struct StorageScope { */ static StorageScope Create(const std::string& s) { StorageScope r; - if (s.compare(0, 6, "global") == 0) { + if (s == "") { + r.rank = StorageRank::kGlobal; + } else if (s.compare(0, 6, "global") == 0) { r.rank = StorageRank::kGlobal; r.tag = s.substr(6, std::string::npos); } else if (s.compare(0, 6, "shared") == 0) { diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 9a4eadb35619..26c08955f5ad 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -260,7 +260,7 @@ void BaseComputeOpNode::GatherBound(const Operation& self, Stmt BaseComputeOpNode::BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const { + const Stmt& body, String storage_scope) const { ICHECK_EQ(stage->op.get(), this); Region bounds; for (IterVar iv : this->axis) { @@ -269,7 +269,7 @@ Stmt BaseComputeOpNode::BuildRealize(const Stage& stage, Stmt realize = body; for (int i = this->num_outputs(); i > 0; --i) { Tensor t = stage->op.output(i - 1); - realize = tir::ProducerRealize(t, bounds, const_true(), realize); + realize = tir::ProducerRealize(t, bounds, const_true(), realize, storage_scope); // alignment requirement, only useful for compute for (size_t i = 0; i < num_schedulable_dims(); ++i) { auto it = stage->iter_var_attrs.find(this->axis[i]); diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index 1c9a3cb336ae..b602efcfc28b 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -124,7 +124,7 @@ void ExternOpNode::GatherBound(const Operation& self, Stmt ExternOpNode::BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const { + const Stmt& body, String storage_scope) const { ICHECK_EQ(stage->op.get(), this); Stmt realize_body = body; for (int k = 0; k < num_outputs(); ++k) { @@ -133,7 +133,7 @@ Stmt ExternOpNode::BuildRealize(const Stage& stage, for (size_t i = 0; i < t->shape.size(); ++i) { bounds.push_back(Range::FromMinExtent(make_const(t->shape[i].dtype(), 0), t->shape[i])); } - realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body); + realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body, storage_scope); } return realize_body; } diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc 
index 65b8660ca1fb..5d2412abb3d2 100644 --- a/src/te/operation/hybrid_op.cc +++ b/src/te/operation/hybrid_op.cc @@ -144,7 +144,7 @@ void HybridOpNode::GatherBound(const Operation& self, Stmt HybridOpNode::BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const { + const Stmt& body, String storage_scope) const { // TODO(@were): Add attribute inject here and remove it from hybrid parser. ICHECK_EQ(stage->op.get(), this); Stmt realize_body = body; @@ -154,7 +154,7 @@ Stmt HybridOpNode::BuildRealize(const Stage& stage, for (size_t i = 0; i < t->shape.size(); ++i) { bounds.push_back(Range::FromMinExtent(make_const(t->shape[i].dtype(), 0), t->shape[i])); } - realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body); + realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body, storage_scope); } return realize_body; } diff --git a/src/te/operation/placeholder_op.cc b/src/te/operation/placeholder_op.cc index c51e53e16cd1..4f5df7ad3024 100644 --- a/src/te/operation/placeholder_op.cc +++ b/src/te/operation/placeholder_op.cc @@ -85,7 +85,7 @@ void PlaceholderOpNode::GatherBound(const Operation& self, Stmt PlaceholderOpNode::BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const { + const Stmt& body, String storage_scope) const { return body; } diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index a555e86097b7..39689bd9654a 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -234,7 +234,7 @@ void ScanOpNode::GatherBound(const Operation& self, } Stmt ScanOpNode::BuildRealize(const Stage& stage, const std::unordered_map& dom_map, - const Stmt& body) const { + const Stmt& body, String storage_scope) const { arith::Analyzer analyzer; ICHECK_EQ(stage->op.get(), this); Range sdom = dom_map.at(this->scan_axis); @@ -250,7 +250,7 @@ Stmt ScanOpNode::BuildRealize(const Stage& stage, const std::unordered_mapspatial_axis_[sp_idx]; bounds.push_back(dom_map.at(sp_ax)); } - ret = tir::ProducerRealize(t, bounds, const_true(), ret); + ret = tir::ProducerRealize(t, bounds, const_true(), ret, storage_scope); } return ret; } diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 1667eb7d1fbd..851d440a6378 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -45,12 +45,20 @@ Array SimplifyArray(arith::Analyzer* ana, Array array) { return array; } -Buffer decl_buffer(Array shape, DataType dtype, String name, Span span) { +Buffer decl_buffer(Array shape, DataType dtype, String name, String storage_scope, + Span span) { DataType storage_dtype = (dtype == DataType::Bool() ? 
DataType::Int(8) : dtype); - return Buffer(Var(name, PointerType(PrimType(storage_dtype)), span), dtype, shape, + return Buffer(Var(name, PointerType(PrimType(storage_dtype), storage_scope), span), dtype, shape, Array(), PrimExpr(), name, "", 0, 0, kDefault, span); } +String GetStorageScope(Var buffer_var) { + auto type = buffer_var->type_annotation; + const auto* ptr_type = type.as(); + ICHECK(ptr_type) << "The provided variable is not of pointer type"; + return ptr_type->storage_scope; +} + // Split the given expression w.r.t the add operator inline std::vector ExprSplitAddition(const PrimExpr& expr) { using namespace tir; diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index b2016eb74c91..6fdeb30ec100 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -377,7 +377,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // ProducerRealize ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, - Stmt body, Span span) { + Stmt body, String storage_scope, Span span) { for (size_t i = 0; i < bounds.size(); ++i) { ICHECK(bounds[i]->min.defined()); ICHECK(bounds[i]->extent.defined()); @@ -394,13 +394,14 @@ ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr node->condition = std::move(condition); node->body = std::move(body); node->span = std::move(span); + node->storage_scope = std::move(storage_scope); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.ProducerRealize") .set_body_typed([](DataProducer producer, Region bounds, PrimExpr condition, Stmt body, - Span span) { - return ProducerRealize(producer, bounds, condition, body, span); + String storage_scope, Span span) { + return ProducerRealize(producer, bounds, condition, body, storage_scope, span); }); TVM_REGISTER_NODE_TYPE(ProducerRealizeNode); diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index 8f757171afbd..896224c0e956 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -22,6 +22,7 @@ */ #include #include +#include #include #include #include @@ -223,14 +224,14 @@ class ThreadSyncInserter : public StmtExprMutator { } PrimExpr VisitExpr_(const LoadNode* op) final { if (sync_scope_.rank == StorageRank::kGlobal && - GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) { + GetScope(op->buffer_var).rank == StorageRank::kGlobal) { ++rw_stats_[op->buffer_var].read_count; } return StmtExprMutator::VisitExpr_(op); } Stmt VisitStmt_(const StoreNode* op) final { if (sync_scope_.rank == StorageRank::kGlobal && - GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) { + GetScope(op->buffer_var).rank == StorageRank::kGlobal) { ++rw_stats_[op->buffer_var].write_count; } return StmtExprMutator::VisitStmt_(op); @@ -250,10 +251,6 @@ class ThreadSyncInserter : public StmtExprMutator { is_lead_ = PrimExpr(); } return ret; - } else if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - storage_scope_[buf] = StorageScope::Create(op->value.as()->value); - return StmtExprMutator::VisitStmt_(op); } else { return StmtExprMutator::VisitStmt_(op); } @@ -264,16 +261,15 @@ class ThreadSyncInserter : public StmtExprMutator { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); ICHECK_EQ(op->args.size(), 5U); - const VarNode* buffer_var = op->args[1].as(); - Var var(GetRef(buffer_var)); + Var buffer_var(GetRef(op->args[1].as())); const IntImmNode* flag = op->args[4].as(); if ((flag->value & 1) && sync_scope_.rank == StorageRank::kGlobal 
&& GetScope(buffer_var).rank == StorageRank::kGlobal) { - ++rw_stats_[var].read_count; + ++rw_stats_[buffer_var].read_count; } if (flag->value & 2 && sync_scope_.rank == StorageRank::kGlobal && GetScope(buffer_var).rank == StorageRank::kGlobal) { - ++rw_stats_[var].write_count; + ++rw_stats_[buffer_var].write_count; } return expr; } else { @@ -287,14 +283,12 @@ class ThreadSyncInserter : public StmtExprMutator { int read_count{0}; int write_count{0}; }; + // Get current storage scope. - StorageScope GetScope(const VarNode* buf) const { - auto it = storage_scope_.find(buf); - StorageScope s; - s.rank = StorageRank::kGlobal; - if (it == storage_scope_.end()) return s; - return it->second; + StorageScope GetScope(Var buffer_var) const { + return StorageScope::Create(GetStorageScope(buffer_var)); } + // private functions. Stmt InitGlobalBarrier(const AttrStmtNode* op) { ICHECK(op != nullptr); @@ -337,8 +331,6 @@ class ThreadSyncInserter : public StmtExprMutator { // data structure. StorageScope sync_scope_; const std::unordered_set& syncs_; - // The storage scope of each buffer - std::unordered_map storage_scope_; // The read write statistics of storage std::unordered_map rw_stats_; // The statistics for global barrier From 014063122547fd2ab77a3cec0d998b2a5ee68038 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 1 Jul 2021 06:26:38 +0900 Subject: [PATCH 02/90] update schedule_ops.cc --- src/te/schedule/schedule_ops.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index 355e3c39494b..f130e1fb93e4 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -51,7 +51,7 @@ Stmt MakePipeline(const Stage& s, const std::unordered_map& dom_ if (consumer.defined() && !is_no_op(consumer)) { pipeline = SeqStmt({producer, consumer}); } - pipeline = s->op->BuildRealize(s, dom_map, pipeline); + pipeline = s->op->BuildRealize(s, dom_map, pipeline, s->scope); // use attribute to mark scope of the operation. 
pipeline = AttrStmt(s->op, tir::attr::realize_scope, StringImm(s->scope), pipeline); @@ -175,8 +175,7 @@ class SchedulePostProc : public StmtExprMutator { thread_extent_scope_.erase(op->node.get()); return ret; } - } else if (op->attr_key == tir::attr::realize_scope || - op->attr_key == tir::attr::double_buffer_scope) { + } else if (op->attr_key == tir::attr::double_buffer_scope) { auto it = replace_op_.find(op->node.get()); if (it != replace_op_.end()) { if (it->second.defined()) { @@ -218,7 +217,8 @@ class SchedulePostProc : public StmtExprMutator { auto it = replace_realize_.find(key); if (it != replace_realize_.end()) { if (it->second.defined()) { - Stmt ret = ProducerRealize(it->second, op->bounds, op->condition, op->body); + Stmt ret = + ProducerRealize(it->second, op->bounds, op->condition, op->body, op->storage_scope); return this->VisitStmt(ret); } else { return this->VisitStmt(op->body); From b6d8e6c1a9ed775624179fd66cc6b9a2a6a4ad8f Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 1 Jul 2021 06:31:11 +0900 Subject: [PATCH 03/90] update schedule_postproc_to_primfunc.cc --- src/te/schedule/schedule_postproc_to_primfunc.cc | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc index 5c59961fe011..8e6cc131b76e 100644 --- a/src/te/schedule/schedule_postproc_to_primfunc.cc +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -49,12 +49,12 @@ namespace tvm { namespace te { // create a buffer for tensor. -Buffer CreateBufferFor(const Tensor& tensor) { +Buffer CreateBufferFor(const Tensor& tensor, String storage_scope = "") { std::string name = tensor->op->name; if (tensor->op->num_outputs() != 1) { name += ".v" + std::to_string(tensor->value_index); } - Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, name); + Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, name, storage_scope); return buffer; } @@ -95,7 +95,7 @@ class TensorToBufferMapper : public StmtExprMutator { Stmt VisitStmt_(const ProducerRealizeNode* op) final { Tensor tensor = Downcast(op->producer); - Buffer buffer = GetOrAllocBuffer(tensor); + Buffer buffer = GetOrAllocBuffer(tensor, op->storage_scope); auto ret = StmtExprMutator::VisitStmt_(op); op = ret.as(); @@ -122,14 +122,16 @@ class TensorToBufferMapper : public StmtExprMutator { } private: - Buffer GetOrAllocBuffer(const Tensor& tensor) { return GetBuffer(tensor, true); } + Buffer GetOrAllocBuffer(const Tensor& tensor, String storage_scope = "") { + return GetBuffer(tensor, storage_scope, true); + } - Buffer GetBuffer(const Tensor& tensor, bool allow_alloc = false) { + Buffer GetBuffer(const Tensor& tensor, String storage_scope = "", bool allow_alloc = false) { auto it = buffer_map_.find(tensor); if (it != buffer_map_.end()) return it->second; ICHECK(allow_alloc) << "Cannot find the Realization point of tensor " << tensor; - auto buffer = CreateBufferFor(tensor); + auto buffer = CreateBufferFor(tensor, storage_scope); buffer_map_[tensor] = buffer; return buffer; } From bebcc5017d35d9bd8f79cde2af1f3e12c8615e03 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 1 Jul 2021 06:41:12 +0900 Subject: [PATCH 04/90] restore more realize_scope This reverts commit b66c3baa54feeb8e34016713a1be21802b3296bf. 
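Note: the net effect of patches 01-04 is easiest to see from the Python side: a storage scope given at buffer creation now lives on the PointerType of the buffer's data Var, and ProducerRealize carries the same scope explicitly instead of relying only on the realize_scope attribute. A minimal sketch, assuming a TVM build with this whole series applied (the buffer name "A" is illustrative):

    import tvm
    from tvm import tir

    # decl_buffer now accepts a storage scope and records it on the
    # pointer type of the underlying data Var rather than in a side map.
    buf = tir.decl_buffer((16, 16), "float32", name="A", scope="shared")
    assert buf.data.type_annotation.storage_scope == "shared"

This mirrors the new C++ signature decl_buffer(shape, dtype, name, storage_scope, span) introduced in patch 01.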
---
 src/te/schedule/schedule_ops.cc | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc
index f130e1fb93e4..21edd2f94b20 100644
--- a/src/te/schedule/schedule_ops.cc
+++ b/src/te/schedule/schedule_ops.cc
@@ -175,7 +175,8 @@ class SchedulePostProc : public StmtExprMutator {
         thread_extent_scope_.erase(op->node.get());
         return ret;
       }
-    } else if (op->attr_key == tir::attr::double_buffer_scope) {
+    } else if (op->attr_key == tir::attr::realize_scope ||
+               op->attr_key == tir::attr::double_buffer_scope) {
       auto it = replace_op_.find(op->node.get());
       if (it != replace_op_.end()) {
         if (it->second.defined()) {

From edeaed2821d929bfee59f1b9141e8b0380b86ae1 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Thu, 1 Jul 2021 07:21:31 +0900
Subject: [PATCH 05/90] make the default scope be "" instead of None in ir
 builder

---
 python/tvm/tir/ir_builder.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py
index 5aae068f4d58..484d00f9611a 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder.py
@@ -394,7 +394,7 @@ def let(self, var_name, value):
         self.emit(lambda x: _stmt.LetStmt(var, value, x))
         return var

-    def allocate(self, dtype, shape, name="buf", scope=None):
+    def allocate(self, dtype, shape, name="buf", scope=""):
         """Create a allocate statement.

         Parameters

From cd6167e9238e6e0837ffd1952d46d2e8becef767 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Thu, 1 Jul 2021 07:29:31 +0900
Subject: [PATCH 06/90] restore realize_scope visit in storage_flatten.cc

---
 src/tir/transforms/storage_flatten.cc | 7 +------
 1 file changed, 1 insertion(+), 6 deletions(-)

diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc
index 43fc1f1ec53f..fab007c5e4d3 100644
--- a/src/tir/transforms/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -79,7 +79,6 @@ class StorageFlattener : public StmtExprMutator {

   Stmt VisitStmt_(const AttrStmtNode* op) final {
     if (op->attr_key == attr::realize_scope) {
-      storage_scope_[op->node.get()] = op->value.as<StringImmNode>()->value;
       return this->VisitStmt(op->body);
     } else if (op->attr_key == attr::double_buffer_scope &&
                op->node->IsInstance<OperationNode>()) {
@@ -156,10 +155,8 @@ class StorageFlattener : public StmtExprMutator {
       shape.push_back(r->extent);
     }
     // deduce current storage scope.
-    auto it = storage_scope_.find(op->buffer.get());
-    ICHECK(it != storage_scope_.end()) << "Cannot find storage scope of " << op->buffer;
     StorageScope skey;
-    const std::string& strkey = it->second;
+    std::string strkey = GetStorageScope(op->buffer->data);
     if (strkey.length() == 0) {
       if (curr_thread_scope_.size() != 0) {
         skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank);
@@ -491,8 +488,6 @@ class StorageFlattener : public StmtExprMutator {
   std::unordered_map<Buffer, BufferEntry, ObjectPtrHash, ObjectPtrEqual> buf_map_;
   // Dimension alignment
   std::unordered_map<Buffer, Array<DimAlignInfo>, ObjectPtrHash, ObjectPtrEqual> dim_align_;
-  // Storage scope
-  std::unordered_map<const Object*, std::string> storage_scope_;
   // The current thread scope.
   std::vector<ThreadScope> curr_thread_scope_;
   // Collects shapes.

From a33fb0d193bda91cc96d4ee68377668e2b4e6b7b Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Thu, 1 Jul 2021 07:30:11 +0900
Subject: [PATCH 07/90] update storage_access.cc

---
 src/tir/transforms/storage_access.cc | 20 ++++++--------------
 src/tir/transforms/storage_access.h  |  4 +---
 2 files changed, 7 insertions(+), 17 deletions(-)

diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc
index 00002d3587db..8f5b8d75c1d4 100644
--- a/src/tir/transforms/storage_access.cc
+++ b/src/tir/transforms/storage_access.cc
@@ -35,7 +35,7 @@ namespace tir {

 void StorageAccessVisitor::VisitExpr_(const LoadNode* op) {
   const VarNode* buf = op->buffer_var.as<VarNode>();
-  StorageScope scope = GetScope(buf);
+  StorageScope scope = GetScope(op->buffer_var);
   if (Enabled(buf, scope)) {
     ICHECK(allow_append_) << op << " " << scope.to_string();
     AccessEntry e;
@@ -56,7 +56,7 @@ void StorageAccessVisitor::VisitStmt_(const StoreNode* op) {
   ICHECK_EQ(curr_stmt_.access.size(), 0U);
   curr_stmt_.stmt = op;
   const VarNode* buf = op->buffer_var.as<VarNode>();
-  StorageScope scope = GetScope(buf);
+  StorageScope scope = GetScope(op->buffer_var);
   if (Enabled(buf, scope)) {
     AccessEntry e;
     e.threads = env_threads();
@@ -90,11 +90,7 @@ void StorageAccessVisitor::VisitStmt_(const EvaluateNode* op) {
 }

 void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) {
-  if (op->attr_key == attr::storage_scope) {
-    const VarNode* buf = op->node.as<VarNode>();
-    storage_scope_[buf] = StorageScope::Create(op->value.as<StringImmNode>()->value);
-    StmtExprVisitor::VisitStmt_(op);
-  } else if (op->attr_key == attr::double_buffer_write) {
+  if (op->attr_key == attr::double_buffer_write) {
     ICHECK(double_buffer_write_ == nullptr);
     double_buffer_write_ = op->node.as<VarNode>();
     scope_.push_back(std::vector<StmtEntry>());
@@ -208,7 +204,7 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) {
     PrimExpr offset = op->args[2];
     PrimExpr extent = op->args[3];
     const IntImmNode* flag = op->args[4].as<IntImmNode>();
-    StorageScope scope = GetScope(buffer);
+    StorageScope scope = GetScope(GetRef<Var>(buffer));
     // The buffer scope.
     if (Enabled(buffer, scope)) {
       ICHECK(allow_append_);
@@ -244,12 +240,8 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) {
   }
 }

-StorageScope StorageAccessVisitor::GetScope(const VarNode* buf) const {
-  auto it = storage_scope_.find(buf);
-  StorageScope s;
-  s.rank = StorageRank::kGlobal;
-  if (it == storage_scope_.end()) return s;
-  return it->second;
+StorageScope StorageAccessVisitor::GetScope(Var buffer_var) const {
+  return StorageScope::Create(GetStorageScope(buffer_var));
 }

 }  // namespace tir
diff --git a/src/tir/transforms/storage_access.h b/src/tir/transforms/storage_access.h
index 663c570fd15c..9dc4c923b054 100644
--- a/src/tir/transforms/storage_access.h
+++ b/src/tir/transforms/storage_access.h
@@ -118,7 +118,7 @@ class StorageAccessVisitor : public StmtExprVisitor {
    * \brief Get the scope of the buffer array.
    * \return The scope of the final buffer array.
    */
-  StorageScope GetScope(const VarNode* buf) const;
+  StorageScope GetScope(Var buffer_var) const;

   // access scope
   std::vector<std::vector<StmtEntry> > scope_;
@@ -135,8 +135,6 @@ class StorageAccessVisitor : public StmtExprVisitor {
   StmtEntry curr_stmt_;
   // The involving threads
   Array<IterVar> env_threads_;
-  // The storage scope of each buffer
-  std::unordered_map<const VarNode*, StorageScope> storage_scope_;
 };

 }  // namespace tir

From e878eae17e2a4d0db02378e5c02035da743a7192 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Thu, 1 Jul 2021 07:58:19 +0900
Subject: [PATCH 08/90] make sure buffer var is of PointerType in ir builder

This reverts commit e650b6c24cabd52a073064e51c2e4fee816e88fd.
---
 python/tvm/tir/ir_builder.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py
index 484d00f9611a..1573d96e7d0d 100644
--- a/python/tvm/tir/ir_builder.py
+++ b/python/tvm/tir/ir_builder.py
@@ -424,7 +424,7 @@ def allocate(self, dtype, shape, name="buf", scope=""):
         self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x))
         return BufferVar(self, buffer_var, shape, dtype)

-    def pointer(self, content_type, name="ptr"):
+    def pointer(self, content_type, name="ptr", scope=""):
         """Create pointer variable with content type.

         Parameters
@@ -435,12 +435,15 @@ def pointer(self, content_type, name="ptr"):
         name : str, optional
             The name of the pointer.

+        scope : str, optional
+            The scope of the buffer.
+
         Returns
         -------
         ptr : BufferVar
             The buffer var representing the buffer.
         """
-        buffer_var = _expr.Var(name, dtype="handle")
+        buffer_var = _expr.Var(name, PointerType(PrimType(content_type), scope))
         return BufferVar(self, buffer_var, None, content_type)

     def buffer_ptr(self, buf, shape=None):

From 9f98deaba5aefecaa57c7d1ec1a32d40d21369c4 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Thu, 1 Jul 2021 14:14:38 +0900
Subject: [PATCH 09/90] enforce default storage scope of global

---
 include/tvm/ir/type.h                            |  2 +-
 include/tvm/tir/buffer.h                         |  3 ++-
 python/tvm/tir/buffer.py                         |  4 ++--
 src/ir/type.cc                                   |  3 ++-
 src/te/operation/cross_thread_reduction.cc       |  4 ++--
 src/te/schedule/schedule_postproc_to_primfunc.cc |  7 ++++---
 src/tir/ir/buffer.cc                             |  8 ++++++++
 src/tir/ir/stmt.cc                               |  6 ++++++
 src/tir/transforms/lower_thread_allreduce.cc     | 12 ++++++------
 src/tir/transforms/storage_flatten.cc            |  4 ++++
 10 files changed, 37 insertions(+), 16 deletions(-)

diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h
index c772650809fa..2c6e0c35a280 100644
--- a/include/tvm/ir/type.h
+++ b/include/tvm/ir/type.h
@@ -184,7 +184,7 @@ class PointerType : public Type {
    * \param element_type The type of the element which the pointer points to.
    * \param storage_scope The storage scope into which the pointer addresses
    */
-  TVM_DLL explicit PointerType(Type element_type, String storage_scope = "");
+  TVM_DLL explicit PointerType(Type element_type, String storage_scope = "global");

   TVM_DEFINE_OBJECT_REF_METHODS(PointerType, Type, PointerTypeNode);
 };
diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h
index ed5718d8f358..c66fa73d8096 100644
--- a/include/tvm/tir/buffer.h
+++ b/include/tvm/tir/buffer.h
@@ -197,7 +197,7 @@ class Buffer : public ObjectRef {
 * \sa Buffer for complete constructor.
 */
 TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype = DataType::Float(32),
-                           String name = "buffer", String storage_scope = "", Span span = Span());
+                           String name = "buffer", String storage_scope = "global", Span span = Span());

 /*!
 * \brief Return the storage scope associated with a buffer variable.
@@ -205,6 +205,7 @@ TVM_DLL Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype = DataType::Flo
 * \return A string representing the storage scope of this buffer variable.
 */
 TVM_DLL String GetStorageScope(Var buffer_var);
+TVM_DLL Var UpdateStorageScope(Var buffer_var, String storage_scope);

 /*!
  * \brief Base node for data producers.
diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py
index d905a53b3303..9c78f8511903 100644
--- a/python/tvm/tir/buffer.py
+++ b/python/tvm/tir/buffer.py
@@ -140,7 +140,7 @@ def decl_buffer(
     data=None,
     strides=None,
     elem_offset=None,
-    scope="",
+    scope="global",
     data_alignment=-1,
     offset_factor=0,
     buffer_type="",
@@ -250,7 +250,7 @@ def decl_buffer(
     # Bool is represented as uint1 in the IR, but stored as int8
     storage_type = PrimType(dtype)
     storage_type = PrimType("int8") if storage_type.dtype == "bool" else storage_type
-    data = Var(name, PointerType(storage_type), span)
+    data = Var(name, PointerType(storage_type, scope), span)
     return _ffi_api.Buffer(
         data,
         dtype,
diff --git a/src/ir/type.cc b/src/ir/type.cc
index fe8e00329bbc..3f450cdf0392 100644
--- a/src/ir/type.cc
+++ b/src/ir/type.cc
@@ -44,6 +44,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     });

 PointerType::PointerType(Type element_type, String storage_scope) {
+  ICHECK(storage_scope != "");
   ObjectPtr<PointerTypeNode> n = make_object<PointerTypeNode>();
   n->element_type = std::move(element_type);
   n->storage_scope = std::move(storage_scope);
@@ -53,7 +54,7 @@ PointerType::PointerType(Type element_type, String storage_scope) {
 TVM_REGISTER_NODE_TYPE(PointerTypeNode);

 TVM_REGISTER_GLOBAL("ir.PointerType")
-    .set_body_typed([](Type element_type, String storage_scope = "") {
+    .set_body_typed([](Type element_type, String storage_scope = "global") {
       return PointerType(element_type, storage_scope);
     });
diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc
index da20dd875ba5..a6ee10edd5a3 100644
--- a/src/te/operation/cross_thread_reduction.cc
+++ b/src/te/operation/cross_thread_reduction.cc
@@ -146,7 +146,7 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
   for (size_t i = 0; i < size; ++i) {
     DataType t = reduces[i]->dtype;
     normal_res_handles.emplace_back("normal_reduce_temp" + std::to_string(i),
-                                    PointerType(PrimType(t)));
+                                    PointerType(PrimType(t), "local"));
     lhs.push_back(Load(t, normal_res_handles[i], 0, const_true(t.lanes())));
   }
   Array<PrimExpr> init_value = combiner->identity_element;
@@ -177,7 +177,7 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage,
   std::vector<Var> res_handles(size);
   for (size_t idx = 0; idx < size; ++idx) {
     DataType dtype = reduces[idx]->dtype;
-    res_handles[idx] = Var("reduce_temp" + std::to_string(idx), PointerType(PrimType(dtype)));
+    res_handles[idx] = Var("reduce_temp" + std::to_string(idx), PointerType(PrimType(dtype), "local"));
     freduce_args.push_back(res_handles[idx]);
   }
diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc
index 8e6cc131b76e..e9caeabcabd0 100644
--- a/src/te/schedule/schedule_postproc_to_primfunc.cc
+++ b/src/te/schedule/schedule_postproc_to_primfunc.cc
@@ -49,7 +49,7 @@ namespace tvm {
 namespace te {

 // create a buffer for tensor.
-Buffer CreateBufferFor(const Tensor& tensor, String storage_scope = "") {
+Buffer CreateBufferFor(const Tensor& tensor, String storage_scope = "global") {
   std::string name = tensor->op->name;
   if (tensor->op->num_outputs() != 1) {
     name += ".v" + std::to_string(tensor->value_index);
@@ -122,11 +122,12 @@ class TensorToBufferMapper : public StmtExprMutator {
   }

  private:
-  Buffer GetOrAllocBuffer(const Tensor& tensor, String storage_scope = "") {
+  Buffer GetOrAllocBuffer(const Tensor& tensor, String storage_scope = "global") {
     return GetBuffer(tensor, storage_scope, true);
   }

-  Buffer GetBuffer(const Tensor& tensor, String storage_scope = "", bool allow_alloc = false) {
+  Buffer GetBuffer(const Tensor& tensor, String storage_scope = "global",
+                   bool allow_alloc = false) {
     auto it = buffer_map_.find(tensor);
     if (it != buffer_map_.end()) return it->second;
     ICHECK(allow_alloc) << "Cannot find the Realization point of tensor " << tensor;
diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc
index 851d440a6378..704afed689cd 100644
--- a/src/tir/ir/buffer.cc
+++ b/src/tir/ir/buffer.cc
@@ -48,6 +48,7 @@ Buffer decl_buffer(Array<PrimExpr> shape, DataType dtype, String name, String storage_scope,
                    Span span) {
   DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype);
+  if (storage_scope == "") storage_scope = "global";
   return Buffer(Var(name, PointerType(PrimType(storage_dtype), storage_scope), span), dtype, shape,
                 Array<PrimExpr>(), PrimExpr(), name, "", 0, 0, kDefault, span);
 }
@@ -59,6 +60,13 @@ String GetStorageScope(Var buffer_var) {
   return ptr_type->storage_scope;
 }

+Var UpdateStorageScope(Var buffer_var, String storage_scope) {
+  auto* ptr_type = buffer_var->type_annotation.as<PointerTypeNode>();
+  ICHECK(ptr_type) << "The provided variable is not of pointer type";
+  return Var(buffer_var->name_hint, PointerType(ptr_type->element_type, storage_scope),
+             buffer_var->span);
+}
+
 // Split the given expression w.r.t the add operator
 inline std::vector<PrimExpr> ExprSplitAddition(const PrimExpr& expr) {
   using namespace tir;
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 6fdeb30ec100..08d8e15dd2b7 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -61,6 +61,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)

 // AttrStmt
 AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span) {
+  if (attr_key == attr::storage_scope) {
+    const VarNode* buf = node.as<VarNode>();
+    CHECK(buf);
+    CHECK(value.as<StringImmNode>()->value == GetStorageScope(GetRef<Var>(buf)))
+        << value.as<StringImmNode>()->value << ", " << GetStorageScope(GetRef<Var>(buf));
+  }
   auto n = make_object<AttrStmtNode>();
   n->node = node;
   n->attr_key = std::move(attr_key);
diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index 9e536814fa12..4a1b31fb8dd4 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -85,13 +85,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     if (it != alloc_remap_.end()) {
       const AllocateNode* repl = it->second.as<AllocateNode>();
       if (warp_allocs_.count(repl)) {
-        stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body);
-        stmt = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("local"), stmt);
+        stmt = Allocate(UpdateStorageScope(repl->buffer_var, "local"), repl->dtype, repl->extents,
+                        repl->condition, op->body);
       } else {
         // use volatile access to shared buffer.
         stmt = AttrStmt(repl->buffer_var, attr::volatile_scope, 1, op->body);
-        stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, stmt);
-        stmt = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("shared"), stmt);
+        stmt = Allocate(UpdateStorageScope(repl->buffer_var, "shared"), repl->dtype, repl->extents,
+                        repl->condition, stmt);
       }
       return stmt;
     } else {
@@ -365,8 +365,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     for (auto var : local_vars) {
       const AllocateNode* repl = var.as<AllocateNode>();
       if (repl) {
-        body = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body);
-        body = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("local"), body);
+        body = Allocate(UpdateStorageScope(repl->buffer_var, "local"), repl->dtype, repl->extents,
+                        repl->condition, body);
       }
     }
diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc
index fab007c5e4d3..57fae5069e3a 100644
--- a/src/tir/transforms/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -197,6 +197,7 @@ class StorageFlattener : public StmtExprMutator {
       strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend());
     }

+    LOG(INFO) << "skey: " << skey.to_string();
     e.buffer = Buffer(Var(op->buffer->data->name_hint, op->buffer->data->type_annotation),
                       op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name,
                       skey.to_string(), align, 0, kDefault);
@@ -225,6 +226,9 @@ class StorageFlattener : public StmtExprMutator {
       ret = Allocate(e.buffer->data, storage_type, shape,
                      make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body);
     }
+    CHECK(e.buffer->scope == GetStorageScope(e.buffer->data))
+        << e.buffer->scope << ", " << GetStorageScope(e.buffer->data) << ", "
+        << GetStorageScope(op->buffer->data);
     ret = AttrStmt(e.buffer->data, attr::storage_scope, StringImm(e.buffer->scope), ret);

     if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {

From 2923f4d96a332bf6ff169311f22f6592724d2143 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 2 Jul 2021 07:03:53 +0900
Subject: [PATCH 10/90] added remap pass but does not work yet

---
 src/tir/transforms/lower_thread_allreduce.cc | 58 +++++++++++++++++---
 1 file changed, 51 insertions(+), 7 deletions(-)

diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index 4a1b31fb8dd4..a9f7671b519e 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -37,6 +37,44 @@
 namespace tvm {
 namespace tir {

+class RemapStorageScope final : public StmtExprMutator {
+ public:
+  explicit RemapStorageScope(const std::unordered_map<const VarNode*, Var>& new_var_remap)
+      : new_var_remap_(new_var_remap) {}
+
+  PrimExpr VisitExpr_(const VarNode* op) final {
+    auto it = new_var_remap_.find(op);
+    LOG(INFO) << "Visit " << op->name_hint;
+    if (it == new_var_remap_.end()) {
+      return GetRef<Var>(op);
+    }
+    LOG(INFO) << "Remapped " << op->name_hint;
+    return it->second;
+  }
+
+  Stmt VisitStmt_(const AllocateNode* op) final {
+    LOG(INFO) << "Visit alloc node with buffer " << op->buffer_var->name_hint;
+    auto remapped = StmtExprMutator::VisitExpr(op->buffer_var);
+    auto body = StmtExprMutator::VisitStmt(op->body);
+    return Allocate(Downcast<Var>(remapped), op->dtype, op->extents, op->condition, body);
+  }
+
+  Stmt VisitStmt_(const StoreNode* op) final {
+    auto remapped = StmtExprMutator::VisitExpr(op->buffer_var);
+    return Store(Downcast<Var>(remapped), StmtExprMutator::VisitExpr(op->value),
+                 StmtExprMutator::VisitExpr(op->index),
+                 StmtExprMutator::VisitExpr(op->predicate));
+  }
+
+  PrimExpr VisitExpr_(const LoadNode* op) final {
+    auto remapped = StmtExprMutator::VisitExpr(op->buffer_var);
+    return Load(op->dtype, Downcast<Var>(remapped), StmtExprMutator::VisitExpr(op->index),
+                StmtExprMutator::VisitExpr(op->predicate));
+  }
+
+ private:
+  std::unordered_map<const VarNode*, Var> new_var_remap_;
+};
+
 class ThreadAllreduceBuilder final : public StmtExprMutator {
  public:
   explicit ThreadAllreduceBuilder(const TargetNode* target)
@@ -85,13 +123,14 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     if (it != alloc_remap_.end()) {
       const AllocateNode* repl = it->second.as<AllocateNode>();
       if (warp_allocs_.count(repl)) {
-        stmt = Allocate(UpdateStorageScope(repl->buffer_var, "local"), repl->dtype, repl->extents,
-                        repl->condition, op->body);
+        stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body);
+        new_var_remap_[repl->buffer_var.get()] = UpdateStorageScope(repl->buffer_var, "local");
       } else {
         // use volatile access to shared buffer.
         stmt = AttrStmt(repl->buffer_var, attr::volatile_scope, 1, op->body);
-        stmt = Allocate(UpdateStorageScope(repl->buffer_var, "shared"), repl->dtype, repl->extents,
-                        repl->condition, stmt);
+        stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, stmt);
+        LOG(INFO) << "make remap for " << repl->buffer_var->name_hint;
+        new_var_remap_[repl->buffer_var.get()] = UpdateStorageScope(repl->buffer_var, "shared");
       }
       return stmt;
     } else {
@@ -108,6 +147,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     }
   }

+  std::unordered_map<const VarNode*, Var> new_var_remap_;
+
  private:
   // Thread entry
   struct ThreadEntry {
@@ -365,8 +406,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     for (auto var : local_vars) {
       const AllocateNode* repl = var.as<AllocateNode>();
       if (repl) {
-        body = Allocate(UpdateStorageScope(repl->buffer_var, "local"), repl->dtype, repl->extents,
-                        repl->condition, body);
+        body = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body);
+        LOG(INFO) << "make remap forr " << repl->buffer_var->name_hint;
+        new_var_remap_[repl->buffer_var.get()] = UpdateStorageScope(repl->buffer_var, "local");
       }
     }

@@ -590,7 +632,9 @@ Pass LowerThreadAllreduce() {
     auto target = f->GetAttr<Target>(tvm::attr::kTarget);
     ICHECK(target.defined()) << "LowerThreadAllreduce: Require the target attribute";
     const TargetNode* target_node = target.as<TargetNode>();
-    n->body = ThreadAllreduceBuilder(target_node)(n->body);
+    ThreadAllreduceBuilder thread_all_reduce(target_node);
+    auto reduce_body = thread_all_reduce(n->body);
+    n->body = RemapStorageScope(thread_all_reduce.new_var_remap_)(reduce_body);
     return f;
   };
   return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {});

From 6a793545f42c3922dc244028a1348b67d5edf822 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 2 Jul 2021 13:40:01 +0900
Subject: [PATCH 11/90] fixed all reduce issue

This reverts commit 8e20003c5325085ed22ee57180aca18644b3b5ab.
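Note: the ir_builder behavior that patches 05 and 08-11 converge on can be sketched from Python as follows, assuming the series is applied (buffer and loop names are made up):

    import tvm
    from tvm import tir

    ib = tir.ir_builder.create()
    # allocate() now threads the scope into the Var's PointerType annotation.
    red_buf = ib.allocate("float32", 64, name="red_buf", scope="shared")
    with ib.for_range(0, 64, name="i") as i:
        red_buf[i] = tvm.tir.FloatImm("float32", 0.0)
    stmt = ib.get()

The resulting Allocate's buffer var is typed PointerType(PrimType("float32"), "shared"); UpdateStorageScope and RemapStorageScope above rewrite exactly this annotation when the all-reduce lowering moves a buffer into "shared" or "local".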
---
 src/tir/transforms/lower_thread_allreduce.cc | 12 +++++-------
 src/tir/transforms/storage_flatten.cc        | 16 ++++++----------
 2 files changed, 11 insertions(+), 17 deletions(-)

diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index a9f7671b519e..5a2bd8ac94c1 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -44,19 +44,17 @@ class RemapStorageScope final : public StmtExprMutator {

   PrimExpr VisitExpr_(const VarNode* op) final {
     auto it = new_var_remap_.find(op);
-    LOG(INFO) << "Visit " << op->name_hint;
     if (it == new_var_remap_.end()) {
       return GetRef<Var>(op);
     }
-    LOG(INFO) << "Remapped " << op->name_hint;
     return it->second;
   }

   Stmt VisitStmt_(const AllocateNode* op) final {
-    LOG(INFO) << "Visit alloc node with buffer " << op->buffer_var->name_hint;
-    auto remapped = StmtExprMutator::VisitExpr(op->buffer_var);
+    auto remapped = Downcast<Var>(StmtExprMutator::VisitExpr(op->buffer_var));
     auto body = StmtExprMutator::VisitStmt(op->body);
-    return Allocate(Downcast<Var>(remapped), op->dtype, op->extents, op->condition, body);
+    auto stmt = Allocate(remapped, op->dtype, op->extents, op->condition, body);
+    return AttrStmt(remapped, attr::storage_scope, StringImm(GetStorageScope(remapped)), stmt);
   }
@@ -129,7 +127,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
         // use volatile access to shared buffer.
         stmt = AttrStmt(repl->buffer_var, attr::volatile_scope, 1, op->body);
         stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, stmt);
-        LOG(INFO) << "make remap for " << repl->buffer_var->name_hint;
+        LOG(INFO) << "make remap for " << repl->buffer_var->name_hint;
         new_var_remap_[repl->buffer_var.get()] = UpdateStorageScope(repl->buffer_var, "shared");
       }
       return stmt;
     } else {
@@ -405,7 +404,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       const AllocateNode* repl = var.as<AllocateNode>();
       if (repl) {
         body = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body);
-        LOG(INFO) << "make remap forr " << repl->buffer_var->name_hint;
+        LOG(INFO) << "make remap forr " << repl->buffer_var->name_hint;
         new_var_remap_[repl->buffer_var.get()] = UpdateStorageScope(repl->buffer_var, "local");
       }
     }
diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc
index 57fae5069e3a..91d5792df97a 100644
--- a/src/tir/transforms/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -157,10 +157,8 @@ class StorageFlattener : public StmtExprMutator {
     // deduce current storage scope.
     StorageScope skey;
     std::string strkey = GetStorageScope(op->buffer->data);
-    if (strkey.length() == 0) {
-      if (curr_thread_scope_.size() != 0) {
-        skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank);
-      }
+    if(curr_thread_scope_.size() != 0 && (strkey == "" || strkey == "global")) {
+      skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank);
     } else {
       skey = StorageScope::Create(strkey);
     }
@@ -197,9 +195,10 @@ class StorageFlattener : public StmtExprMutator {
       strides = Array<PrimExpr>(rstrides.rbegin(), rstrides.rend());
     }

-    LOG(INFO) << "skey: " << skey.to_string();
-    e.buffer = Buffer(Var(op->buffer->data->name_hint, op->buffer->data->type_annotation),
-                      op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name,
+    auto* ptr_type = op->buffer->data->type_annotation.as<PointerTypeNode>();
+    ICHECK(ptr_type);
+    auto new_var = Var(op->buffer->data->name_hint, PointerType(ptr_type->element_type, skey.to_string()));
+    e.buffer = Buffer(new_var, op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name,
                       skey.to_string(), align, 0, kDefault);

     buf_map_[key] = e;
@@ -226,9 +225,6 @@ class StorageFlattener : public StmtExprMutator {
       ret = Allocate(e.buffer->data, storage_type, shape,
                      make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body);
     }
-    CHECK(e.buffer->scope == GetStorageScope(e.buffer->data))
-        << e.buffer->scope << ", " << GetStorageScope(e.buffer->data) << ", "
-        << GetStorageScope(op->buffer->data);
     ret = AttrStmt(e.buffer->data, attr::storage_scope, StringImm(e.buffer->scope), ret);

     if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {

From a566af98fd5f8e9d055b7fb543446ca66b34725d Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 2 Jul 2021 13:48:11 +0900
Subject: [PATCH 12/90] simplify

---
 src/tir/transforms/lower_thread_allreduce.cc | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index 5a2bd8ac94c1..8cb7db902db6 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -52,8 +52,8 @@ class RemapStorageScope final : public StmtExprMutator {

   Stmt VisitStmt_(const AllocateNode* op) final {
     auto remapped = Downcast<Var>(StmtExprMutator::VisitExpr(op->buffer_var));
-    auto body = StmtExprMutator::VisitStmt(op->body);
-    auto stmt = Allocate(remapped, op->dtype, op->extents, op->condition, body);
+    auto stmt = Allocate(remapped, op->dtype, op->extents, op->condition,
+                         StmtExprMutator::VisitStmt(op->body));
     return AttrStmt(remapped, attr::storage_scope, StringImm(GetStorageScope(remapped)), stmt);
   }
@@ -127,7 +127,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
         // use volatile access to shared buffer.
         stmt = AttrStmt(repl->buffer_var, attr::volatile_scope, 1, op->body);
         stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, stmt);
-        LOG(INFO) << "make remap for " << repl->buffer_var->name_hint;
         new_var_remap_[repl->buffer_var.get()] = UpdateStorageScope(repl->buffer_var, "shared");
       }
       return stmt;
@@ -405,7 +404,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       const AllocateNode* repl = var.as<AllocateNode>();
       if (repl) {
         body = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body);
-        LOG(INFO) << "make remap forr " << repl->buffer_var->name_hint;
         new_var_remap_[repl->buffer_var.get()] = UpdateStorageScope(repl->buffer_var, "local");
       }
     }

From 345573a44c9cf77907e0b7fc6aa344365cba4349 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 2 Jul 2021 13:50:13 +0900
Subject: [PATCH 13/90] trying mitigation for aot test

---
 src/tir/transforms/storage_access.cc | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc
index 8f5b8d75c1d4..952758c4e5a7 100644
--- a/src/tir/transforms/storage_access.cc
+++ b/src/tir/transforms/storage_access.cc
@@ -241,7 +241,10 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) {
 }

 StorageScope StorageAccessVisitor::GetScope(Var buffer_var) const {
-  return StorageScope::Create(GetStorageScope(buffer_var));
+  if (buffer_var->type_annotation.as<PointerTypeNode>()) {
+    return StorageScope::Create(GetStorageScope(buffer_var));
+  }
+  return StorageScope();  // global by default
 }

 }  // namespace tir

From be2ab963e132500d2ef178526dee5117df6207df Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 2 Jul 2021 14:00:52 +0900
Subject: [PATCH 14/90] merge remaining changes from initial branch

---
 src/target/spirv/codegen_spirv.cc     | 14 +++++-----
 src/target/spirv/codegen_spirv.h      |  2 --
 src/tir/transforms/storage_rewrite.cc | 37 +++++++++------------------
 3 files changed, 19 insertions(+), 34 deletions(-)

diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc
index 5d52bee44e98..7c9dfcaf95e0 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -23,6 +23,7 @@
  */
 #include "codegen_spirv.h"

+#include
 #include
 #include
 #include
@@ -644,13 +645,14 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) {
   ICHECK(!op->dtype.is_handle());
   int32_t constant_size = op->constant_allocation_size();
   ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU";
+
   spirv::Value buf;
-  StorageInfo& info = storage_info_[op->buffer_var.get()];
+  auto storage_scope = runtime::StorageScope::Create(GetStorageScope(op->buffer_var));
   spirv::SType etype = builder_->GetSType(op->dtype);
-  if (info.scope.rank == runtime::StorageRank::kLocal) {
+  if (storage_scope.rank == runtime::StorageRank::kLocal) {
     buf = builder_->Allocate(etype, static_cast<uint32_t>(constant_size), spv::StorageClassFunction);
-  } else if (info.scope.rank == runtime::StorageRank::kShared) {
+  } else if (storage_scope.rank == runtime::StorageRank::kShared) {
     // Shared memory
     buf = builder_->Allocate(etype, static_cast<uint32_t>(constant_size), spv::StorageClassWorkgroup);

   builder_->SetName(buf, op->buffer_var->name_hint);

+  StorageInfo& info = storage_info_[op->buffer_var.get()];
   ICHECK(!info.content_fixed);
   info.UpdateContentType(op->dtype);
+
   ICHECK(!var_map_.count(op->buffer_var.get()));
   var_map_[op->buffer_var.get()] = buf;
   this->VisitStmt(op->body);
@@ -677,10 +681,6 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) {
       var_map_[iv->var.get()] = GetThreadIndex(iv, op->value);
       }
     }
-  } else if (op->attr_key == tir::attr::storage_scope) {
-    const VarNode* v = op->node.as<VarNode>();
-    ICHECK(v);
-    storage_info_[v].scope = runtime::StorageScope::Create(op->value.as<StringImmNode>()->value);
   } else if (op->attr_key == tir::attr::volatile_scope) {
     const VarNode* v = op->node.as<VarNode>();
     ICHECK(v);
diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h
index 3868322a74e0..a44dc5fd3d34 100644
--- a/src/target/spirv/codegen_spirv.h
+++ b/src/target/spirv/codegen_spirv.h
@@ -116,8 +116,6 @@ class CodeGenSPIRV : public ExprFunctor<spirv::Value(const PrimExpr&)>,
  protected:
   /*! \brief The storage information */
   struct StorageInfo {
-    /*! \brief The storage scope */
-    runtime::StorageScope scope;
     /*! \brief Whether it is volatile */
     bool is_volatile{false};
     /*! \brief Whether it is volatile */
diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index c755576e2b88..d39d1f2ed901 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -75,8 +75,6 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
   };
   // The scope of each allocation
   struct AllocEntry {
-    // Scope used for allocation.
-    StorageScope storage_scope;
     // scope level
     size_t level{0};
     // allocation stmt
     const AllocateNode* alloc{nullptr};
   };

   void VisitStmt_(const AllocateNode* op) final {
     size_t level = scope_.size();
     const VarNode* buf = op->buffer_var.get();
-    auto it = alloc_info_.find(buf);
-    ICHECK(it != alloc_info_.end()) << "Could not find buffer `" << buf->name_hint
-                                    << "` in the list of allocated buffers. Perhaps you are "
-                                       "missing a storage_scope attr for this buffer.";
-    ICHECK(it->second.alloc == nullptr);
-    it->second.alloc = op;
-    it->second.level = level;
+    alloc_info_[buf].alloc = op;
+    alloc_info_[buf].level = level;
     StmtExprVisitor::VisitStmt_(op);
   }
   void VisitStmt_(const StoreNode* op) final {
@@ -180,10 +173,6 @@ class LinearAccessPatternFinder final : public StmtExprVisitor {
       VisitNewScope(op);
     } else if (op->attr_key == attr::virtual_thread) {
       VisitNewScope(op);
-    } else if (op->attr_key == attr::storage_scope) {
-      const VarNode* buf = op->node.as<VarNode>();
-      alloc_info_[buf].storage_scope = StorageScope::Create(op->value.as<StringImmNode>()->value);
-      StmtExprVisitor::VisitStmt_(op);
     } else {
       StmtExprVisitor::VisitStmt_(op);
     }
@@ -409,10 +398,8 @@ class StoragePlanRewriter : public StmtExprMutator {
   }

   Stmt VisitStmt_(const AttrStmtNode* op) final {
-    if (op->attr_key == attr::storage_scope) {
-      return this->VisitStmt(op->body);
-    } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread ||
-               attr::IsPragmaKey(op->attr_key)) {
+    if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread ||
+        attr::IsPragmaKey(op->attr_key)) {
       // remake all the allocation at the attach scope.
       if (attach_map_.count(op)) {
         auto& svec = attach_map_[op];
@@ -716,7 +703,8 @@ class StoragePlanRewriter : public StmtExprMutator {

       for (const VarNode* var : it->second.gen) {
         ICHECK(alloc_info.count(var));
-        const AllocEntry& ae = alloc_info.at(var);
+        const AllocateNode* alloc = alloc_info.at(var).alloc;
+        auto storage_scope = StorageScope::Create(GetStorageScope(GetRef<Var>(var)));
         StorageEntry* dst_entry = nullptr;
         // inplace detection
         if (detect_inplace) {
@@ -726,13 +714,12 @@ class StoragePlanRewriter : public StmtExprMutator {
           if (!inplace_flag.count(src) && alloc_map_.count(src)) {
             InplaceOpVerifier visitor;
             StorageEntry* src_entry = alloc_map_.at(src);
-            if (src_entry->scope == ae.storage_scope &&
+            if (src_entry->scope == storage_scope &&
                 src_entry->attach_scope_ == thread_scope_ &&
-                src_entry->elem_type == ae.alloc->dtype.element_of() &&
+                src_entry->elem_type == alloc->dtype.element_of() &&
                 visitor.Check(s.stmt, var, src)) {
-              uint64_t const_nbits =
-                  static_cast<uint64_t>(ae.alloc->constant_allocation_size()) *
-                  ae.alloc->dtype.bits() * ae.alloc->dtype.lanes();
+              uint64_t const_nbits = static_cast<uint64_t>(alloc->constant_allocation_size()) *
+                                     alloc->dtype.bits() * alloc->dtype.lanes();
               if (src_entry->const_nbits == const_nbits && !inplace_found) {
                 // successfully inplace
                 dst_entry = src_entry;
         if (dst_entry == nullptr) {
-          dst_entry = FindAlloc(ae.alloc, thread_scope_, ae.storage_scope);
+          dst_entry = FindAlloc(alloc, thread_scope_, storage_scope);
         }
-        dst_entry->allocs.emplace_back(ae.alloc);
+        dst_entry->allocs.emplace_back(alloc);
         alloc_map_[var] = dst_entry;
       }
     }

From f83cba0d464c88cc550ab1e524cb3200391b1442 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 2 Jul 2021 14:11:50 +0900
Subject: [PATCH 15/90] remove use of attr::storage_scope from codegen

---
 src/target/llvm/codegen_amdgpu.cc | 27 ++++++++--------
 src/target/llvm/codegen_cpu.cc    |  6 ++--
 src/target/llvm/codegen_llvm.cc   | 53 ++++++++++++++-----------------
 src/target/llvm/codegen_llvm.h    | 11 ++-----
 src/target/llvm/codegen_nvptx.cc  | 26 +++++++--------
 src/target/source/codegen_cuda.cc |  8 ++---
 6 files changed, 58 insertions(+), 73 deletions(-)

diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc
index 78f8a50e4e1b..01d2b2f7ad4d 100644
--- a/src/target/llvm/codegen_amdgpu.cc
+++ b/src/target/llvm/codegen_amdgpu.cc
@@ -76,30 +76,31 @@ class CodeGenAMDGPU : public CodeGenLLVM {
     int32_t constant_size = op->constant_allocation_size();
     ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU";

-    StorageInfo& info = alloc_storage_info_[op->buffer_var.get()];
-    if (constant_size % 4 == 0 && info.alignment == 0) {
-      info.alignment = GetTempAllocaAlignment(op->dtype, constant_size);
+    int& alignment = alloc_storage_alignment_[op->buffer_var.get()];
+    if (constant_size % 4 == 0 && alignment == 0) {
+      alignment = GetTempAllocaAlignment(op->dtype, constant_size);
     }
     // maximum necessary alignment in the AMD devices
-    if (info.alignment > 16) {
-      info.alignment = 16;
+    if (alignment > 16) {
+      alignment = 16;
     }
-    if (info.scope.rank == runtime::StorageRank::kLocal) {
+    auto storage_scope = runtime::StorageScope::Create(GetStorageScope(op->buffer_var));
+    if (storage_scope.rank == runtime::StorageRank::kLocal) {
       // const int local_address_space = 5;
       // TODO(tqchen): for higher version of LLVM, local address space can be set.
       llvm::AllocaInst* alloca = WithFunctionEntry([&]() {
         return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size));
       });
-      if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) {
+      if (alloca->getAlignment() < static_cast<uint32_t>(alignment)) {
 #if TVM_LLVM_VERSION >= 100
-        alloca->setAlignment(llvm::Align(info.alignment));
+        alloca->setAlignment(llvm::Align(alignment));
 #else
-        alloca->setAlignment(info.alignment);
+        alloca->setAlignment(alignment);
 #endif
       }
       buf = alloca;
     } else {
-      ICHECK(info.scope.rank == runtime::StorageRank::kShared)
+      ICHECK(storage_scope.rank == runtime::StorageRank::kShared)
           << "Can only allocate shared or local memory inside kernel";
       // Shared memory: address space == 3
       const unsigned shared_address_space = 3;
       llvm::Type* type = llvm::ArrayType::get(DTypeToLLVMType(op->dtype), constant_size);
       llvm::GlobalVariable* global = new llvm::GlobalVariable(
           *module_, type, false, llvm::GlobalValue::PrivateLinkage, nullptr, ".shared", nullptr,
           llvm::GlobalValue::NotThreadLocal, shared_address_space);
-      if (global->getAlignment() < static_cast<uint32_t>(info.alignment)) {
+      if (global->getAlignment() < static_cast<uint32_t>(alignment)) {
 #if TVM_LLVM_VERSION >= 100
-        global->setAlignment(llvm::Align(info.alignment));
+        global->setAlignment(llvm::Align(alignment));
 #else
-        global->setAlignment(info.alignment);
+        global->setAlignment(alignment);
 #endif
       }
       buf = global;
diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc
index ab96d6e69d14..b9761355b208 100644
--- a/src/target/llvm/codegen_cpu.cc
+++ b/src/target/llvm/codegen_cpu.cc
@@ -463,9 +463,9 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) {
   }
   // Add alignment attribute if needed.
 #if TVM_LLVM_VERSION >= 50
-  auto f = alloc_storage_info_.find(var.get());
-  if (f != alloc_storage_info_.end()) {
-    unsigned align = f->second.alignment;
+  auto f = alloc_storage_alignment_.find(var.get());
+  if (f != alloc_storage_alignment_.end()) {
+    unsigned align = f->second;
     if (align > 1) {
       auto attr = llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, align);
       fcompute->addParamAttr(idx, attr);
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 48ccefafe3c4..545c94dddae3 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -104,7 +104,7 @@ void CodeGenLLVM::AddFunction(const PrimFunc& f) { this->AddFunctionInternal(f,
 void CodeGenLLVM::InitFuncState() {
   var_map_.clear();
   alias_var_set_.clear();
-  alloc_storage_info_.clear();
+  alloc_storage_alignment_.clear();
   volatile_buf_.clear();
   analyzer_.reset(new arith::Analyzer());
 }
@@ -165,9 +165,9 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) {
 #if TVM_LLVM_VERSION >= 50
   for (size_t i = 0; i < f->params.size(); ++i) {
     const Var& var = f->params[i];
-    auto f = alloc_storage_info_.find(var.get());
-    if (f != alloc_storage_info_.end()) {
-      unsigned align = f->second.alignment;
+    auto f = alloc_storage_alignment_.find(var.get());
+    if (f != alloc_storage_alignment_.end()) {
+      unsigned align = f->second;
       if (align > 1) {
         auto attr = llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, align);
         function_->addParamAttr(i, attr);
@@ -498,11 +498,12 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, const VarNode* buffer, P
 void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExpr& index,
                                int* p_alignment, int* p_native_bits) {
   int max_align_bits = t.bits();
   auto it = alloc_storage_info_.find(buf_var);
   if (it != alloc_storage_info_.end()) {
     const
StorageInfo& info = it->second; - *p_native_bits = NativeVectorBits(info.scope); - max_align_bits = info.alignment * 8; + auto it = alloc_storage_alignment_.find(buf_var); + if (it != alloc_storage_alignment_.end()) { + const int alignment = it->second; + *p_native_bits = + NativeVectorBits(runtime::StorageScope::Create(GetStorageScope(GetRef(buf_var)))); + max_align_bits = alignment * 8; } else { *p_native_bits = native_vector_bits_; } @@ -1353,25 +1354,25 @@ void CodeGenLLVM::VisitStmt_(const AllocateNode* op) { int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation"; - StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; - if (constant_size % 4 == 0 && info.alignment == 0) { - info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); + int& alignment = alloc_storage_alignment_[op->buffer_var.get()]; + if (constant_size % 4 == 0 && alignment == 0) { + alignment = GetTempAllocaAlignment(op->dtype, constant_size); } // maximum necessary alignment in the NV devices - if (info.alignment > 16) { - info.alignment = 16; + if (alignment > 16) { + alignment = 16; } llvm::AllocaInst* alloca = WithFunctionEntry([&]() { return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); }); - if (alloca->getAlignment() < static_cast(info.alignment)) { + if (alloca->getAlignment() < static_cast(alignment)) { #if TVM_LLVM_VERSION >= 100 - alloca->setAlignment(llvm::Align(info.alignment)); + alloca->setAlignment(llvm::Align(alignment)); #else - alloca->setAlignment(info.alignment); + alloca->setAlignment(alignment); #endif } - info.alignment = alloca->getAlignment(); + alignment = alloca->getAlignment(); buf = alloca; buf = builder_->CreatePointerCast( @@ -1390,18 +1391,13 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) { analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value)); } } - } else if (op->attr_key == tir::attr::storage_scope) { - const VarNode* v = op->node.as(); - ICHECK(v); - alloc_storage_info_[v].scope = - runtime::StorageScope::Create(op->value.as()->value); } else if (op->attr_key == tir::attr::storage_alignment) { const VarNode* v = op->node.as(); ICHECK(v); - alloc_storage_info_[v].alignment = static_cast(op->value.as()->value); - if (var_map_.count(v) && alloc_storage_info_[v].alignment > 1) { + alloc_storage_alignment_[v] = static_cast(op->value.as()->value); + if (var_map_.count(v) && alloc_storage_alignment_[v] > 1) { builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v), - alloc_storage_info_[v].alignment); + alloc_storage_alignment_[v]); } } else if (op->attr_key == tir::attr::volatile_scope) { const VarNode* v = op->node.as(); @@ -1426,9 +1422,8 @@ void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) { } var_map_[v] = MakeValue(op->value); analyzer_->Bind(op->var, op->value); - if (alloc_storage_info_.count(v) && alloc_storage_info_[v].alignment > 1) { - builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v), - alloc_storage_info_[v].alignment); + if (alloc_storage_alignment_.count(v) && alloc_storage_alignment_[v] > 1) { + builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v), alloc_storage_alignment_[v]); } this->VisitStmt(op->body); } diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index d5fcfab6d889..fb13ce42f897 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -161,13 +161,6 @@ class CodeGenLLVM : public ExprFunctor, void VisitStmt_(const 
EvaluateNode* op) override; protected: - /*! \brief The storage information */ - struct StorageInfo { - /*! \brief The storage scope */ - runtime::StorageScope scope; - /*! \brief The alignment of allocation */ - int alignment{0}; - }; /*! * \brief Execute falloca at the beginning of the * currrent function and obtain its return value. @@ -327,8 +320,8 @@ class CodeGenLLVM : public ExprFunctor, std::vector > link_modules_; /*! \brief native vector bits of current targetx*/ int native_vector_bits_{0}; - /*! \brief the storage scope of allocation */ - std::unordered_map alloc_storage_info_; + /*! \brief the alignment of allocation */ + std::unordered_map alloc_storage_alignment_; // The definition of local variable. std::unordered_map var_map_; // global strings diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index 9e56529ec9ef..e8ae088ece32 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -51,31 +51,31 @@ class CodeGenNVPTX : public CodeGenLLVM { int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; - StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; - if (constant_size % 4 == 0 && info.alignment == 0) { - info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); + int& alignment = alloc_storage_alignment_[op->buffer_var.get()]; + if (constant_size % 4 == 0 && alignment == 0) { + alignment = GetTempAllocaAlignment(op->dtype, constant_size); } // maximum necessary alignment in the NV devices - if (info.alignment > 16) { - info.alignment = 16; + if (alignment > 16) { + alignment = 16; } - - if (info.scope.rank == runtime::StorageRank::kLocal) { + auto storage_scope = runtime::StorageScope::Create(GetStorageScope(op->buffer_var)); + if (storage_scope.rank == runtime::StorageRank::kLocal) { // const int local_address_space = 5; // TODO(tqchen): for higher version of LLVM, local address space can be set. 
llvm::AllocaInst* alloca = WithFunctionEntry([&]() { return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); }); - if (alloca->getAlignment() < static_cast(info.alignment)) { + if (alloca->getAlignment() < static_cast(alignment)) { #if TVM_LLVM_VERSION >= 100 - alloca->setAlignment(llvm::Align(info.alignment)); + alloca->setAlignment(llvm::Align(alignment)); #else - alloca->setAlignment(info.alignment); + alloca->setAlignment(alignment); #endif } buf = alloca; } else { - ICHECK(info.scope.rank == runtime::StorageRank::kShared) + ICHECK(storage_scope.rank == runtime::StorageRank::kShared) << "Can only allocate shared or local memory inside kernel"; // Shared memory: address space == 3 const unsigned shared_address_space = 3; @@ -85,9 +85,9 @@ class CodeGenNVPTX : public CodeGenLLVM { *module_, type, false, llvm::GlobalValue::PrivateLinkage, nullptr, ".shared", nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space); #if TVM_LLVM_VERSION >= 100 - global->setAlignment(llvm::Align(info.alignment)); + global->setAlignment(llvm::Align(alignment)); #else - global->setAlignment(info.alignment); + global->setAlignment(alignment); #endif buf = global; } diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 6e76c3538e71..66b401c731e3 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -705,12 +705,7 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { this->PrintIndent(); int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; - const VarNode* buffer = op->buffer_var.as(); - auto it = alloc_storage_scope_.find(buffer); - ICHECK(it != alloc_storage_scope_.end()) - << "Buffer " << op->buffer_var << " is missing an AttrStmt with a \"storage_scope\" key"; - - std::string scope = it->second; + std::string scope = GetStorageScope(op->buffer_var); if (scope.find("wmma.") == 0) { if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) || @@ -724,6 +719,7 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { op->dtype == DataType::Int(32)) << "Accumulator only support half, float and int type for now"; } + const VarNode* buffer = op->buffer_var.as(); constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); PrintWmmaScope(scope, op->dtype, buffer, stream); } else { From f1b0f3cbf5d6911c7f44dd642239d4c95ada66b7 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 2 Jul 2021 14:12:19 +0900 Subject: [PATCH 16/90] restore a visit to AttrStmt with attr::storage_scope in storage_rewrite --- src/tir/transforms/storage_rewrite.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index d39d1f2ed901..e99b0c0e1086 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -398,8 +398,10 @@ class StoragePlanRewriter : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread || - attr::IsPragmaKey(op->attr_key)) { + if (op->attr_key == attr::storage_scope) { + return this->VisitStmt(op->body); + } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread || + attr::IsPragmaKey(op->attr_key)) { // remake all the allocation at the attach scope. 
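// [Illustrative sketch, not part of the patch] The restored branch above can
// drop a storage_scope AttrStmt and recurse straight into its body because,
// after this series, the attribute is redundant with the pointer type. The
// two encodings below carry the same scope information:
//
//   // scope carried on the buffer var's type annotation (new encoding)
//   Var buffer_var("buf", PointerType(PrimType(DataType::Float(32)), "shared"));
//
//   // the same scope as a legacy attribute wrapping the allocation body
//   Stmt legacy = AttrStmt(buffer_var, attr::storage_scope,
//                          StringImm("shared"), alloc_body);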
if (attach_map_.count(op)) { auto& svec = attach_map_[op]; From 30edcd66b874699fc1d9343612523c541a0f0e22 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 2 Jul 2021 14:21:51 +0900 Subject: [PATCH 17/90] disable check --- src/ir/type.cc | 1 - src/tir/ir/stmt.cc | 13 +++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/ir/type.cc b/src/ir/type.cc index 3f450cdf0392..567e31d9c2a6 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -44,7 +44,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); PointerType::PointerType(Type element_type, String storage_scope) { - ICHECK(storage_scope != ""); ObjectPtr n = make_object(); n->element_type = std::move(element_type); n->storage_scope = std::move(storage_scope); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 08d8e15dd2b7..18a5c2691005 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -61,12 +61,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // AttrStmt AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span) { - if (attr_key == attr::storage_scope) { - const VarNode* buf = node.as(); - CHECK(buf); - CHECK(value.as()->value == GetStorageScope(GetRef(buf))) - << value.as()->value << ", " << GetStorageScope(GetRef(buf)); - } + // TODO(masahi): Enable this invariant check + // if (attr_key == attr::storage_scope) { + // const VarNode* buf = node.as(); + // ICHECK(buf); + // ICHECK(value.as()->value == GetStorageScope(GetRef(buf))) + // << value.as()->value << ", " << GetStorageScope(GetRef(buf)); + // } auto n = make_object(); n->node = node; n->attr_key = std::move(attr_key); From 9560f1b2735eeb9aa2ff7cd7026d60b2fa06f78b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 2 Jul 2021 14:26:23 +0900 Subject: [PATCH 18/90] lint fix --- include/tvm/tir/buffer.h | 3 ++- src/te/operation/cross_thread_reduction.cc | 4 +++- src/tir/transforms/storage_flatten.cc | 5 +++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index c66fa73d8096..bf5c1ceaaf5c 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -197,7 +197,8 @@ class Buffer : public ObjectRef { * \sa Buffer for complete constructor. */ TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Float(32), - String name = "buffer", String storage_scope = "global", Span span = Span()); + String name = "buffer", String storage_scope = "global", + Span span = Span()); /*! * \brief Return the storage scope associated with a buffer variable. diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc index a6ee10edd5a3..0c20328f02b7 100644 --- a/src/te/operation/cross_thread_reduction.cc +++ b/src/te/operation/cross_thread_reduction.cc @@ -1,3 +1,4 @@ + /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. 
See the NOTICE file @@ -177,7 +178,8 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, std::vector res_handles(size); for (size_t idx = 0; idx < size; ++idx) { DataType dtype = reduces[idx]->dtype; - res_handles[idx] = Var("reduce_temp" + std::to_string(idx), PointerType(PrimType(dtype), "local")); + res_handles[idx] = + Var("reduce_temp" + std::to_string(idx), PointerType(PrimType(dtype), "local")); freduce_args.push_back(res_handles[idx]); } diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 91d5792df97a..7af36d164a56 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -157,7 +157,7 @@ class StorageFlattener : public StmtExprMutator { // deduce current storage scope. StorageScope skey; std::string strkey = GetStorageScope(op->buffer->data); - if(curr_thread_scope_.size() != 0 && (strkey == "" || strkey == "global")) { + if (curr_thread_scope_.size() != 0 && (strkey == "" || strkey == "global")) { skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank); } else { skey = StorageScope::Create(strkey); @@ -197,7 +197,8 @@ class StorageFlattener : public StmtExprMutator { auto* ptr_type = op->buffer->data->type_annotation.as(); ICHECK(ptr_type); - auto new_var = Var(op->buffer->data->name_hint, PointerType(ptr_type->element_type, skey.to_string())); + auto new_var = + Var(op->buffer->data->name_hint, PointerType(ptr_type->element_type, skey.to_string())); e.buffer = Buffer(new_var, op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name, skey.to_string(), align, 0, kDefault); From 3de3edd671d4c98ff4b81b9b0ee2d55317e3fbc9 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 2 Jul 2021 14:52:18 +0900 Subject: [PATCH 19/90] revert default scope to "" --- include/tvm/ir/type.h | 2 +- include/tvm/tir/buffer.h | 2 +- python/tvm/tir/buffer.py | 2 +- src/ir/type.cc | 2 +- src/te/schedule/schedule_postproc_to_primfunc.cc | 6 +++--- src/tir/ir/buffer.cc | 1 - src/tir/transforms/storage_flatten.cc | 7 ++++--- 7 files changed, 11 insertions(+), 11 deletions(-) diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 2c6e0c35a280..c772650809fa 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -184,7 +184,7 @@ class PointerType : public Type { * \param element_type The type of the element which the pointer points to. * \param storage_scope The storage scope into which the pointer addresses */ - TVM_DLL explicit PointerType(Type element_type, String storage_scope = "global"); + TVM_DLL explicit PointerType(Type element_type, String storage_scope = ""); TVM_DEFINE_OBJECT_REF_METHODS(PointerType, Type, PointerTypeNode); }; diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index bf5c1ceaaf5c..fd6718a44e4b 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -197,7 +197,7 @@ class Buffer : public ObjectRef { * \sa Buffer for complete constructor. */ TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Float(32), - String name = "buffer", String storage_scope = "global", + String name = "buffer", String storage_scope = "", Span span = Span()); /*! 
diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index 9c78f8511903..eb48e4c8068b 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -140,7 +140,7 @@ def decl_buffer( data=None, strides=None, elem_offset=None, - scope="global", + scope="", data_alignment=-1, offset_factor=0, buffer_type="", diff --git a/src/ir/type.cc b/src/ir/type.cc index 567e31d9c2a6..fe8e00329bbc 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -53,7 +53,7 @@ PointerType::PointerType(Type element_type, String storage_scope) { TVM_REGISTER_NODE_TYPE(PointerTypeNode); TVM_REGISTER_GLOBAL("ir.PointerType") - .set_body_typed([](Type element_type, String storage_scope = "global") { + .set_body_typed([](Type element_type, String storage_scope = "") { return PointerType(element_type, storage_scope); }); diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc index e9caeabcabd0..b80f76f5acbf 100644 --- a/src/te/schedule/schedule_postproc_to_primfunc.cc +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -49,7 +49,7 @@ namespace tvm { namespace te { // create a buffer for tensor. -Buffer CreateBufferFor(const Tensor& tensor, String storage_scope = "global") { +Buffer CreateBufferFor(const Tensor& tensor, String storage_scope = "") { std::string name = tensor->op->name; if (tensor->op->num_outputs() != 1) { name += ".v" + std::to_string(tensor->value_index); @@ -122,11 +122,11 @@ class TensorToBufferMapper : public StmtExprMutator { } private: - Buffer GetOrAllocBuffer(const Tensor& tensor, String storage_scope = "global") { + Buffer GetOrAllocBuffer(const Tensor& tensor, String storage_scope = "") { return GetBuffer(tensor, storage_scope, true); } - Buffer GetBuffer(const Tensor& tensor, String storage_scope = "global", + Buffer GetBuffer(const Tensor& tensor, String storage_scope = "", bool allow_alloc = false) { auto it = buffer_map_.find(tensor); if (it != buffer_map_.end()) return it->second; diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 704afed689cd..49da7c7f5630 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -48,7 +48,6 @@ Array SimplifyArray(arith::Analyzer* ana, Array array) { Buffer decl_buffer(Array shape, DataType dtype, String name, String storage_scope, Span span) { DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype); - if (storage_scope == "") storage_scope = "global"; return Buffer(Var(name, PointerType(PrimType(storage_dtype), storage_scope), span), dtype, shape, Array(), PrimExpr(), name, "", 0, 0, kDefault, span); } diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 7af36d164a56..eca3bba83583 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -157,12 +157,13 @@ class StorageFlattener : public StmtExprMutator { // deduce current storage scope. 
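// [Illustrative sketch, not part of the patch] With the defaults reverted to
// "", an empty scope string again means "unspecified" rather than "global",
// letting the deduction below pick a rank from the enclosing thread scope:
//
//   Type unspecified = PointerType(PrimType(DataType::Float(32)));  // scope ""
//   Type global_ptr = PointerType(PrimType(DataType::Float(32)), "global");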
StorageScope skey; std::string strkey = GetStorageScope(op->buffer->data); - if (curr_thread_scope_.size() != 0 && (strkey == "" || strkey == "global")) { - skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank); + if (strkey.length() == 0) { + if (curr_thread_scope_.size() != 0) { + skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank); + } } else { skey = StorageScope::Create(strkey); } - // use small alignment for small arrays auto dtype = op->buffer->dtype; int32_t const_size = AllocateNode::constant_allocation_size(shape); From b703d8b883b55c138a2fa51599eb32717c43d5e2 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 2 Jul 2021 14:56:42 +0900 Subject: [PATCH 20/90] format --- include/tvm/tir/buffer.h | 3 +-- src/te/schedule/schedule_postproc_to_primfunc.cc | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index fd6718a44e4b..f01158967bdd 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -197,8 +197,7 @@ class Buffer : public ObjectRef { * \sa Buffer for complete constructor. */ TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Float(32), - String name = "buffer", String storage_scope = "", - Span span = Span()); + String name = "buffer", String storage_scope = "", Span span = Span()); /*! * \brief Return the storage scope associated with a buffer variable. diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc index b80f76f5acbf..8e6cc131b76e 100644 --- a/src/te/schedule/schedule_postproc_to_primfunc.cc +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -126,8 +126,7 @@ class TensorToBufferMapper : public StmtExprMutator { return GetBuffer(tensor, storage_scope, true); } - Buffer GetBuffer(const Tensor& tensor, String storage_scope = "", - bool allow_alloc = false) { + Buffer GetBuffer(const Tensor& tensor, String storage_scope = "", bool allow_alloc = false) { auto it = buffer_map_.find(tensor); if (it != buffer_map_.end()) return it->second; ICHECK(allow_alloc) << "Cannot find the Realization point of tensor " << tensor; From b7abe5a814034148a5a26c8b5799cc9350b63f49 Mon Sep 17 00:00:00 2001 From: masa Date: Sun, 4 Jul 2021 10:58:25 +0900 Subject: [PATCH 21/90] fix volatile access to shared mem in lower all reduce --- src/tir/transforms/lower_thread_allreduce.cc | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 8cb7db902db6..68bf24abb847 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -52,9 +52,17 @@ class RemapStorageScope final : public StmtExprMutator { Stmt VisitStmt_(const AllocateNode* op) final { auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); - auto stmt = Allocate(remapped, op->dtype, op->extents, op->condition, - StmtExprMutator::VisitStmt(op->body)); - return AttrStmt(remapped, attr::storage_scope, StringImm(GetStorageScope(remapped)), stmt); + auto new_scope = GetStorageScope(remapped); + if (new_scope != GetStorageScope(op->buffer_var)) { + Stmt body = StmtExprMutator::VisitStmt(op->body); + if (new_scope == "shared") { + // use volatile access to shared buffer. 
+        body = AttrStmt(remapped, attr::volatile_scope, 1, body);
+      }
+      body = Allocate(remapped, op->dtype, op->extents, op->condition, body);
+      return AttrStmt(remapped, attr::storage_scope, StringImm(new_scope), body);
+    }
+    return StmtExprMutator::VisitStmt_(op);
  }

  Stmt VisitStmt_(const StoreNode* op) final {
@@ -124,9 +132,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
      stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body);
      new_var_remap_[repl->buffer_var.get()] = UpdateStorageScope(repl->buffer_var, "local");
    } else {
-      // use volatile access to shared buffer.
-      stmt = AttrStmt(repl->buffer_var, attr::volatile_scope, 1, op->body);
-      stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, stmt);
+      stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body);
      new_var_remap_[repl->buffer_var.get()] = UpdateStorageScope(repl->buffer_var, "shared");
    }
    return stmt;

From 62f818cce066b13915d0f4aa1b49b658de76bbae Mon Sep 17 00:00:00 2001
From: masa
Date: Sun, 4 Jul 2021 11:48:31 +0900
Subject: [PATCH 22/90] fixed gpu cooperative load/store test

---
 Jenkinsfile                           | 1 -
 src/tir/transforms/storage_rewrite.cc | 3 ++-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/Jenkinsfile b/Jenkinsfile
index 815c07ad8806..f26b148085fb 100644
--- a/Jenkinsfile
+++ b/Jenkinsfile
@@ -282,7 +282,6 @@ stage('Unit Test') {
       timeout(time: max_time, unit: 'MINUTES') {
         sh "${docker_run} ${ci_arm} ./tests/scripts/task_ci_setup.sh"
         sh "${docker_run} ${ci_arm} ./tests/scripts/task_python_unittest.sh"
-        sh "${docker_run} ${ci_arm} ./tests/scripts/task_python_arm_compute_library.sh"
         junit "build/pytest-results/*.xml"
         // sh "${docker_run} ${ci_arm} ./tests/scripts/task_python_integration.sh"
       }
diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index e99b0c0e1086..a18bc84604aa 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -922,7 +922,8 @@ class VectorAllocRewriter : public StmtExprMutator {
                extents[extents.size() - 1] / make_const(extents[0].dtype(), factor));
     // create a new buffer var
     DataType new_dtype = tvec[0];
-    Var new_buffer_var(op->buffer_var->name_hint, PointerType(PrimType(new_dtype)));
+    Var new_buffer_var(op->buffer_var->name_hint,
+                       PointerType(PrimType(new_dtype), GetStorageScope(op->buffer_var)));
     // update the remap req.
     var_remap_.Set(op->buffer_var, new_buffer_var);
     return Allocate(new_buffer_var, new_dtype, extents, op->condition, op->body);
   }

From d41bdc816f83c810083ddb10721f31f2262b4c88 Mon Sep 17 00:00:00 2001
From: masa
Date: Tue, 6 Jul 2021 06:58:44 +0900
Subject: [PATCH 23/90] pass storage scope to PointerType in tvm script parser

This reverts commit 99cfb9d18781dcfdea169d920450f9063ab18b6b.
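The Jenkinsfile change in patch 22 only unblocks CI; the substantive fix is in
VectorAllocRewriter, and it illustrates a general rule for this series: any
pass that re-creates a buffer Var must copy the storage scope from the old
pointer type, or the scope silently resets to "". A minimal sketch of the
rule, illustrative only, using the GetStorageScope helper defined earlier in
this series:

    // Rebuild a buffer var with a new element type while keeping its scope.
    Var RemakeBufferVar(const Var& old_var, DataType new_dtype) {
      return Var(old_var->name_hint,
                 PointerType(PrimType(new_dtype), GetStorageScope(old_var)));
    }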
--- Jenkinsfile | 1 + python/tvm/script/scope_handler.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/Jenkinsfile b/Jenkinsfile index f26b148085fb..815c07ad8806 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -282,6 +282,7 @@ stage('Unit Test') { timeout(time: max_time, unit: 'MINUTES') { sh "${docker_run} ${ci_arm} ./tests/scripts/task_ci_setup.sh" sh "${docker_run} ${ci_arm} ./tests/scripts/task_python_unittest.sh" + sh "${docker_run} ${ci_arm} ./tests/scripts/task_python_arm_compute_library.sh" junit "build/pytest-results/*.xml" // sh "${docker_run} ${ci_arm} ./tests/scripts/task_python_integration.sh" } diff --git a/python/tvm/script/scope_handler.py b/python/tvm/script/scope_handler.py index a23401d926e9..d07209485bd4 100644 --- a/python/tvm/script/scope_handler.py +++ b/python/tvm/script/scope_handler.py @@ -140,7 +140,7 @@ def enter_scope( def setup_buffer_var(extents, dtype, scope, condition=True, span: Span = None): """Setup buffer var for a given type.""" - buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype)) + buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), scope) self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span) setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span)) From 594d34b2db352bc301428f5aac6e282ecec2180c Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 6 Jul 2021 15:16:17 +0900 Subject: [PATCH 24/90] fixed tvmscript roundtrip test --- python/tvm/script/special_stmt.py | 16 ++++++++++++++++ src/printer/tvmscript_printer.cc | 14 ++++++++++++-- .../python/unittest/test_tvmscript_roundtrip.py | 4 ++-- 3 files changed, 30 insertions(+), 4 deletions(-) diff --git a/python/tvm/script/special_stmt.py b/python/tvm/script/special_stmt.py index 7eb938c58f96..befa37e19252 100644 --- a/python/tvm/script/special_stmt.py +++ b/python/tvm/script/special_stmt.py @@ -491,6 +491,22 @@ def var(dtype, span): super().__init__(var, def_symbol=True) +@register +class BufferVarDef(SpecialStmt): + """Special function for defining a Var""" + + def __init__(self): + def buffer_var(dtype, storage_scope, span): + assert isinstance( + self.node, ast.Assign + ), f"VarDef expected ast.Assign but got {type(self.node)}" + ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope) + v = te.var(self.node.lhs.id.name, ptr_type, span=span) + self.context.update_symbol(v.name, v, self.node) + + super().__init__(buffer_var, def_symbol=True) + + @register class EnvThread(SpecialStmt): """Bind a var to thread env""" diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 4bbe17064c87..e855712617ca 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -1013,8 +1014,17 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { return memo_var_[GetRef(a)].str() < memo_var_[GetRef(b)].str(); }); for (const auto& var : vars) { - header_var << Doc::NewLine() << Print(GetRef(var)) << " = tir.var("; - header_var << PrintDType(var->dtype) << ")"; + auto type = GetRef(var)->type_annotation; + if (auto* ptr_type = type.as()) { + auto* prim_type = ptr_type->element_type.as(); + ICHECK(prim_type); + header_var << Doc::NewLine() << Print(GetRef(var)) << " = tir.buffer_var("; + header_var << PrintDType(prim_type->dtype) << ", " + << Doc::StrLiteral(ptr_type->storage_scope) << ")"; + } else { + header_var << Doc::NewLine() << Print(GetRef(var)) << " = tir.var("; + header_var << PrintDType(var->dtype) << ")"; + 
} } } doc << Doc::Indent(4, header_attr << header_var << header_buf << body); diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 164949552859..6c0e228e8e4c 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -277,8 +277,8 @@ def mmult( } ) # var definition - C_global = tir.var("handle") - packedB = tir.var("handle") + C_global = tir.buffer_var("float32", "global") + packedB = tir.buffer_var("float32", "global") # body assert num_args == 3, "mmult: num_args should be 3" arg0: ty.handle = tir.tvm_struct_get(args, 0, 12, dtype="handle") From 151eb0281e96035e4f0d99b9e0ae691afa72087b Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 6 Jul 2021 15:37:19 +0900 Subject: [PATCH 25/90] fixed tir flatten buffer test --- python/tvm/script/special_stmt.py | 4 ++-- .../python/unittest/test_tir_transform_flatten_buffer.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tvm/script/special_stmt.py b/python/tvm/script/special_stmt.py index befa37e19252..6dbbb7354b26 100644 --- a/python/tvm/script/special_stmt.py +++ b/python/tvm/script/special_stmt.py @@ -493,13 +493,13 @@ def var(dtype, span): @register class BufferVarDef(SpecialStmt): - """Special function for defining a Var""" + """Special function for defining a Var of pointer type""" def __init__(self): def buffer_var(dtype, storage_scope, span): assert isinstance( self.node, ast.Assign - ), f"VarDef expected ast.Assign but got {type(self.node)}" + ), f"BufferVarDef expected ast.Assign but got {type(self.node)}" ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope) v = te.var(self.node.lhs.id.name, ptr_type, span=span) self.context.update_symbol(v.name, v, self.node) diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index c997748649cd..6929a329ac0f 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -35,7 +35,7 @@ def compacted_elementwise_func(a: ty.handle, c: ty.handle) -> None: with tir.block([]): tir.reads(A[i, 0:16]) tir.writes(C[i, 0:16]) - B = tir.alloc_buffer([1, 16], "float32") + B = tir.alloc_buffer([1, 16], "float32", scope="global") for j in range(0, 16): with tir.block() as []: tir.reads(A[i, j]) @@ -111,7 +111,7 @@ def compacted_symbolic_func(a: ty.handle, c: ty.handle, n: ty.int32, m: ty.int32 with tir.block([]): tir.reads(A[i, m]) tir.writes(C[i, m]) - B = tir.alloc_buffer((m,), "float32") + B = tir.alloc_buffer((m,), "float32", scope="global") for j in range(0, m): with tir.block([]) as []: tir.reads(A[i, j]) @@ -190,8 +190,8 @@ def compacted_multi_alloc_func(a: ty.handle, d: ty.handle) -> None: with tir.block([]) as []: tir.reads(A[i]) tir.writes(D[i]) - B = tir.alloc_buffer((32,)) - C = tir.alloc_buffer((32,)) + B = tir.alloc_buffer((32,), scope="global") + C = tir.alloc_buffer((32,), scope="global") B[i] = A[i] + 1.0 C[i] = A[i] + B[i] D[i] = C[i] * 2.0 From 1e883353a080a5cdc28c4e1374a2686b09261871 Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 6 Jul 2021 16:00:05 +0900 Subject: [PATCH 26/90] fixed test_tir_transform_hoist_if.py --- tests/python/unittest/test_tir_transform_hoist_if.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_transform_hoist_if.py b/tests/python/unittest/test_tir_transform_hoist_if.py index 
252a187dbdc5..b111e2be75c7 100644 --- a/tests/python/unittest/test_tir_transform_hoist_if.py +++ b/tests/python/unittest/test_tir_transform_hoist_if.py @@ -636,7 +636,7 @@ def test_hoisting_block_scope_4(): def test_hoisting_block_scope_5(): ib = tvm.tir.ir_builder.create() - data = ib.pointer("float32", name="data") + data = ib.pointer("float32", name="data", scope="global") l = te.var("l") m = te.var("m") n = te.var("n") From 7e8a7c1c3ac90a9ca2dfbaa9e247dded1e13f556 Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 6 Jul 2021 16:00:28 +0900 Subject: [PATCH 27/90] use storage scope global by default in aot_executor_codegen.cc --- src/relay/backend/aot_executor_codegen.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 9b495adbdea8..9b613bbff99f 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -722,7 +722,8 @@ class AOTExecutorCodegen : public ExprVisitor { // Define the storage allocator ids for (auto kv : storage_device_map_) { for (auto sid : kv.second->storage_ids) { - te::Var buffer_var(MakeString("sid_", sid), PointerType(PrimType(DataType::Int(8)))); + te::Var buffer_var(MakeString("sid_", sid), + PointerType(PrimType(DataType::Int(8)), "global")); sids_table_[sid] = buffer_var; } } From d458187956788c4c3020f050c787864f7195140b Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 6 Jul 2021 16:33:56 +0900 Subject: [PATCH 28/90] add missing default storage scope in create_primfunc.cc --- src/te/operation/create_primfunc.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 190892b2283f..a47556bac101 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -109,7 +109,7 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te:: } // Step 2. Declare buffer and update op2buffers - Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint()); + Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint(), "global"); info->tensor2buffers[tensor] = buffer; // Step 3. Add Buffer to root_alloc @@ -270,7 +270,8 @@ PrimFunc CreatePrimFunc(const Array& arg_list) { const te::Tensor& tensor = op.output(0); // Check op is in op list ICHECK(info.IsArg(tensor)); - const Buffer& buffer = decl_buffer(placeholder->shape, placeholder->dtype, placeholder->name); + const Buffer& buffer = + decl_buffer(placeholder->shape, placeholder->dtype, placeholder->name, "global"); info.tensor2buffers[tensor] = buffer; } else if (const auto* compute_op = op.as()) { // Case 2. 
ComputeOp (te.compute) From 410fd4525a7c6f247e0b725b43dbca53b21839b7 Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 7 Jul 2021 06:17:00 +0900 Subject: [PATCH 29/90] restore StorageInfo struct in llvm backend --- src/target/llvm/codegen_amdgpu.cc | 22 +++++++-------- src/target/llvm/codegen_cpu.cc | 6 ++--- src/target/llvm/codegen_llvm.cc | 45 ++++++++++++++++--------------- src/target/llvm/codegen_llvm.h | 9 +++++-- src/target/llvm/codegen_nvptx.cc | 20 +++++++------- 5 files changed, 54 insertions(+), 48 deletions(-) diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 01d2b2f7ad4d..4b182e17ed74 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -76,13 +76,13 @@ class CodeGenAMDGPU : public CodeGenLLVM { int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; - int& alignment = alloc_storage_alignment_[op->buffer_var.get()]; - if (constant_size % 4 == 0 && alignment == 0) { - alignment = GetTempAllocaAlignment(op->dtype, constant_size); + StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; + if (constant_size % 4 == 0 && info.alignment == 0) { + info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); } // maximum necessary alignment in the AMD devices - if (alignment > 16) { - alignment = 16; + if (info.alignment > 16) { + info.alignment = 16; } auto storage_scope = runtime::StorageScope::Create(GetStorageScope(op->buffer_var)); if (storage_scope.rank == runtime::StorageRank::kLocal) { @@ -91,11 +91,11 @@ class CodeGenAMDGPU : public CodeGenLLVM { llvm::AllocaInst* alloca = WithFunctionEntry([&]() { return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); }); - if (alloca->getAlignment() < static_cast(alignment)) { + if (alloca->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 - alloca->setAlignment(llvm::Align(alignment)); + alloca->setAlignment(llvm::Align(info.alignment)); #else - alloca->setAlignment(alignment); + alloca->setAlignment(info.alignment); #endif } buf = alloca; @@ -109,11 +109,11 @@ class CodeGenAMDGPU : public CodeGenLLVM { llvm::GlobalVariable* global = new llvm::GlobalVariable( *module_, type, false, llvm::GlobalValue::PrivateLinkage, nullptr, ".shared", nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space); - if (global->getAlignment() < static_cast(alignment)) { + if (global->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 - global->setAlignment(llvm::Align(alignment)); + global->setAlignment(llvm::Align(info.alignment)); #else - global->setAlignment(alignment); + global->setAlignment(info.alignment); #endif } buf = global; diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index b9761355b208..ab96d6e69d14 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -463,9 +463,9 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { } // Add alignment attribute if needed. 
#if TVM_LLVM_VERSION >= 50 - auto f = alloc_storage_alignment_.find(var.get()); - if (f != alloc_storage_alignment_.end()) { - unsigned align = f->second; + auto f = alloc_storage_info_.find(var.get()); + if (f != alloc_storage_info_.end()) { + unsigned align = f->second.alignment; if (align > 1) { auto attr = llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, align); fcompute->addParamAttr(idx, attr); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 545c94dddae3..9701a9f9ebb0 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -104,7 +104,7 @@ void CodeGenLLVM::AddFunction(const PrimFunc& f) { this->AddFunctionInternal(f, void CodeGenLLVM::InitFuncState() { var_map_.clear(); alias_var_set_.clear(); - alloc_storage_alignment_.clear(); + alloc_storage_info_.clear(); volatile_buf_.clear(); analyzer_.reset(new arith::Analyzer()); } @@ -165,9 +165,9 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { #if TVM_LLVM_VERSION >= 50 for (size_t i = 0; i < f->params.size(); ++i) { const Var& var = f->params[i]; - auto f = alloc_storage_alignment_.find(var.get()); - if (f != alloc_storage_alignment_.end()) { - unsigned align = f->second; + auto f = alloc_storage_info_.find(var.get()); + if (f != alloc_storage_info_.end()) { + unsigned align = f->second.alignment; if (align > 1) { auto attr = llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, align); function_->addParamAttr(i, attr); @@ -498,12 +498,12 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, const VarNode* buffer, P void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExpr& index, int* p_alignment, int* p_native_bits) { int max_align_bits = t.bits(); - auto it = alloc_storage_alignment_.find(buf_var); - if (it != alloc_storage_alignment_.end()) { - const int alignment = it->second; + auto it = alloc_storage_info_.find(buf_var); + if (it != alloc_storage_info_.end()) { + const StorageInfo& info = it->second; *p_native_bits = NativeVectorBits(runtime::StorageScope::Create(GetStorageScope(GetRef(buf_var)))); - max_align_bits = alignment * 8; + max_align_bits = info.alignment * 8; } else { *p_native_bits = native_vector_bits_; } @@ -1354,25 +1354,25 @@ void CodeGenLLVM::VisitStmt_(const AllocateNode* op) { int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation"; - int& alignment = alloc_storage_alignment_[op->buffer_var.get()]; - if (constant_size % 4 == 0 && alignment == 0) { - alignment = GetTempAllocaAlignment(op->dtype, constant_size); + StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; + if (constant_size % 4 == 0 && info.alignment == 0) { + info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); } // maximum necessary alignment in the NV devices - if (alignment > 16) { - alignment = 16; + if (info.alignment > 16) { + info.alignment = 16; } llvm::AllocaInst* alloca = WithFunctionEntry([&]() { return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); }); - if (alloca->getAlignment() < static_cast(alignment)) { + if (alloca->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 - alloca->setAlignment(llvm::Align(alignment)); + alloca->setAlignment(llvm::Align(info.alignment)); #else - alloca->setAlignment(alignment); + alloca->setAlignment(info.alignment); #endif } - alignment = alloca->getAlignment(); + info.alignment = alloca->getAlignment(); buf = 
alloca; buf = builder_->CreatePointerCast( @@ -1394,10 +1394,10 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) { } else if (op->attr_key == tir::attr::storage_alignment) { const VarNode* v = op->node.as(); ICHECK(v); - alloc_storage_alignment_[v] = static_cast(op->value.as()->value); - if (var_map_.count(v) && alloc_storage_alignment_[v] > 1) { + alloc_storage_info_[v].alignment = static_cast(op->value.as()->value); + if (var_map_.count(v) && alloc_storage_info_[v].alignment > 1) { builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v), - alloc_storage_alignment_[v]); + alloc_storage_info_[v].alignment); } } else if (op->attr_key == tir::attr::volatile_scope) { const VarNode* v = op->node.as(); @@ -1422,8 +1422,9 @@ void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) { } var_map_[v] = MakeValue(op->value); analyzer_->Bind(op->var, op->value); - if (alloc_storage_alignment_.count(v) && alloc_storage_alignment_[v] > 1) { - builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v), alloc_storage_alignment_[v]); + if (alloc_storage_info_.count(v) && alloc_storage_info_[v].alignment > 1) { + builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v), + alloc_storage_info_[v].alignment); } this->VisitStmt(op->body); } diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index fb13ce42f897..810e59be7214 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -161,6 +161,11 @@ class CodeGenLLVM : public ExprFunctor, void VisitStmt_(const EvaluateNode* op) override; protected: + /*! \brief The storage information */ + struct StorageInfo { + /*! \brief The alignment of allocation */ + int alignment{0}; + }; /*! * \brief Execute falloca at the beginning of the * currrent function and obtain its return value. @@ -320,8 +325,8 @@ class CodeGenLLVM : public ExprFunctor, std::vector > link_modules_; /*! \brief native vector bits of current targetx*/ int native_vector_bits_{0}; - /*! \brief the alignment of allocation */ - std::unordered_map alloc_storage_alignment_; + /*! \brief the storage scope of allocation */ + std::unordered_map alloc_storage_info_; // The definition of local variable. 
std::unordered_map var_map_; // global strings diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index e8ae088ece32..18faf34143f0 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -51,13 +51,13 @@ class CodeGenNVPTX : public CodeGenLLVM { int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; - int& alignment = alloc_storage_alignment_[op->buffer_var.get()]; - if (constant_size % 4 == 0 && alignment == 0) { - alignment = GetTempAllocaAlignment(op->dtype, constant_size); + StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; + if (constant_size % 4 == 0 && info.alignment == 0) { + info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); } // maximum necessary alignment in the NV devices - if (alignment > 16) { - alignment = 16; + if (info.alignment > 16) { + info.alignment = 16; } auto storage_scope = runtime::StorageScope::Create(GetStorageScope(op->buffer_var)); if (storage_scope.rank == runtime::StorageRank::kLocal) { @@ -66,11 +66,11 @@ class CodeGenNVPTX : public CodeGenLLVM { llvm::AllocaInst* alloca = WithFunctionEntry([&]() { return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); }); - if (alloca->getAlignment() < static_cast(alignment)) { + if (alloca->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 - alloca->setAlignment(llvm::Align(alignment)); + alloca->setAlignment(llvm::Align(info.alignment)); #else - alloca->setAlignment(alignment); + alloca->setAlignment(info.alignment); #endif } buf = alloca; @@ -85,9 +85,9 @@ class CodeGenNVPTX : public CodeGenLLVM { *module_, type, false, llvm::GlobalValue::PrivateLinkage, nullptr, ".shared", nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space); #if TVM_LLVM_VERSION >= 100 - global->setAlignment(llvm::Align(alignment)); + global->setAlignment(llvm::Align(info.alignment)); #else - global->setAlignment(alignment); + global->setAlignment(info.alignment); #endif buf = global; } From 742c243333442cf7eb9e3f35b4dde563ee7c0a88 Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 7 Jul 2021 06:43:04 +0900 Subject: [PATCH 30/90] UpdateStorageScope -> WithStorageScope --- include/tvm/tir/buffer.h | 2 +- src/tir/ir/buffer.cc | 2 +- src/tir/transforms/lower_thread_allreduce.cc | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index f01158967bdd..d1e1c5b5103a 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -205,7 +205,7 @@ TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Flo * \return A string representing the storage scope of this buffer variable. */ TVM_DLL String GetStorageScope(Var buffer_var); -TVM_DLL Var UpdateStorageScope(Var buffer_var, String storage_scope); +TVM_DLL Var WithStorageScope(Var buffer_var, String storage_scope); /*! * \brief Base node for data producers. 
diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 49da7c7f5630..670e503d6114 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -59,7 +59,7 @@ String GetStorageScope(Var buffer_var) { return ptr_type->storage_scope; } -Var UpdateStorageScope(Var buffer_var, String storage_scope) { +Var WithStorageScope(Var buffer_var, String storage_scope) { auto* ptr_type = buffer_var->type_annotation.as(); ICHECK(ptr_type) << "The provided variable is not of pointer type"; return Var(buffer_var->name_hint, PointerType(ptr_type->element_type, storage_scope), diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 68bf24abb847..17f07012c6f3 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -130,10 +130,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { const AllocateNode* repl = it->second.as(); if (warp_allocs_.count(repl)) { stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body); - new_var_remap_[repl->buffer_var.get()] = UpdateStorageScope(repl->buffer_var, "local"); + new_var_remap_[repl->buffer_var.get()] = WithStorageScope(repl->buffer_var, "local"); } else { stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body); - new_var_remap_[repl->buffer_var.get()] = UpdateStorageScope(repl->buffer_var, "shared"); + new_var_remap_[repl->buffer_var.get()] = WithStorageScope(repl->buffer_var, "shared"); } return stmt; } else { @@ -410,7 +410,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { const AllocateNode* repl = var.as(); if (repl) { body = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body); - new_var_remap_[repl->buffer_var.get()] = UpdateStorageScope(repl->buffer_var, "local"); + new_var_remap_[repl->buffer_var.get()] = WithStorageScope(repl->buffer_var, "local"); } } From aa90d42fdadbc21064cdd00cb613eb5c25069f81 Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 7 Jul 2021 07:07:06 +0900 Subject: [PATCH 31/90] fixed lower warp memory test --- src/tir/transforms/lower_thread_allreduce.cc | 4 +++ src/tir/transforms/lower_warp_memory.cc | 38 +++++++++++++++++--- 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 17f07012c6f3..9488d1e99fcb 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -37,6 +37,8 @@ namespace tvm { namespace tir { +namespace { + class RemapStorageScope final : public StmtExprMutator { public: explicit RemapStorageScope(const std::unordered_map& new_var_remap) @@ -81,6 +83,8 @@ class RemapStorageScope final : public StmtExprMutator { std::unordered_map new_var_remap_; }; +} // namespace + class ThreadAllreduceBuilder final : public StmtExprMutator { public: explicit ThreadAllreduceBuilder(const TargetNode* target) diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index b95681a936ca..8f7382a21153 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -44,6 +44,34 @@ namespace tvm { namespace tir { +namespace { + +class RemapStorageScope final : public StmtExprMutator { + public: + explicit RemapStorageScope(const std::unordered_map& new_var_remap) + : new_var_remap_(new_var_remap) {} + + Stmt VisitStmt_(const AttrStmtNode* op) { + using runtime::StorageScope; + if 
(op->attr_key == attr::storage_scope) { + const VarNode* buf = op->node.as(); + auto it = new_var_remap_.find(buf); + if (it != new_var_remap_.end()) { + auto remapped = it->second; + auto new_scope = GetStorageScope(remapped); + return AttrStmt(remapped, attr::storage_scope, StringImm(new_scope), + StmtMutator::VisitStmt(op->body)); + } + } + return StmtMutator::VisitStmt_(op); + } + + private: + std::unordered_map new_var_remap_; +}; + +} // namespace + // Rewrite Rule // // There is no special warp memory in most GPUs. @@ -356,6 +384,8 @@ class WarpMemoryRewriter : private StmtMutator { return stmt; } + std::unordered_map new_var_remap_; + private: Stmt VisitStmt_(const AllocateNode* op) { auto ret = StmtMutator::VisitStmt_(op); @@ -374,9 +404,7 @@ class WarpMemoryRewriter : private StmtMutator { StorageScope scope = StorageScope::Create(op->value.as()->value); if (scope.rank == runtime::StorageRank::kWarp) { warp_buffer_.insert(buf); - Stmt ret = StmtMutator::VisitStmt_(op); - op = ret.as(); - return AttrStmt(op->node, op->attr_key, StringImm("local"), op->body); + new_var_remap_[buf] = WithStorageScope(GetRef(buf), "local"); } } return StmtMutator::VisitStmt_(op); @@ -397,7 +425,9 @@ Pass LowerWarpMemory() { auto target = f->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; int warp_size = target.value()->GetAttr("thread_warp_size", 1).value(); - n->body = WarpMemoryRewriter(warp_size).Rewrite(std::move(n->body)); + WarpMemoryRewriter warp_memory_rewriter(warp_size); + auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body)); + n->body = RemapStorageScope(warp_memory_rewriter.new_var_remap_)(stmt); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {}); From 1e9a7a32796bfe8ed71bf2107afee1a4f33f8493 Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 7 Jul 2021 07:33:29 +0900 Subject: [PATCH 32/90] GetStorageScope -> GetPtrStorageScope --- include/tvm/tir/buffer.h | 2 +- src/target/llvm/codegen_amdgpu.cc | 2 +- src/target/llvm/codegen_llvm.cc | 2 +- src/target/llvm/codegen_nvptx.cc | 2 +- src/target/source/codegen_cuda.cc | 2 +- src/target/spirv/codegen_spirv.cc | 2 +- src/tir/ir/buffer.cc | 2 +- src/tir/transforms/lower_thread_allreduce.cc | 4 ++-- src/tir/transforms/lower_warp_memory.cc | 2 +- src/tir/transforms/storage_access.cc | 2 +- src/tir/transforms/storage_flatten.cc | 2 +- src/tir/transforms/storage_rewrite.cc | 4 ++-- src/tir/transforms/thread_storage_sync.cc | 2 +- 13 files changed, 15 insertions(+), 15 deletions(-) diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index d1e1c5b5103a..73416d7fad2c 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -204,7 +204,7 @@ TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Flo * \param buffer_var The input buffer variable. * \return A string representing the storage scope of this buffer variable. */ -TVM_DLL String GetStorageScope(Var buffer_var); +TVM_DLL String GetPtrStorageScope(Var buffer_var); TVM_DLL Var WithStorageScope(Var buffer_var, String storage_scope); /*! 
diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 4b182e17ed74..9aec8f4e867b 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -84,7 +84,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { if (info.alignment > 16) { info.alignment = 16; } - auto storage_scope = runtime::StorageScope::Create(GetStorageScope(op->buffer_var)); + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); if (storage_scope.rank == runtime::StorageRank::kLocal) { // const int local_address_space = 5; // TODO(tqchen): for higher version of LLVM, local address space can be set. diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 9701a9f9ebb0..bdae93b82aff 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -502,7 +502,7 @@ void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExp if (it != alloc_storage_info_.end()) { const StorageInfo& info = it->second; *p_native_bits = - NativeVectorBits(runtime::StorageScope::Create(GetStorageScope(GetRef(buf_var)))); + NativeVectorBits(runtime::StorageScope::Create(GetPtrStorageScope(GetRef(buf_var)))); max_align_bits = info.alignment * 8; } else { *p_native_bits = native_vector_bits_; diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index 18faf34143f0..43ea0e6b7ae9 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -59,7 +59,7 @@ class CodeGenNVPTX : public CodeGenLLVM { if (info.alignment > 16) { info.alignment = 16; } - auto storage_scope = runtime::StorageScope::Create(GetStorageScope(op->buffer_var)); + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); if (storage_scope.rank == runtime::StorageRank::kLocal) { // const int local_address_space = 5; // TODO(tqchen): for higher version of LLVM, local address space can be set. 
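Note: the LLVM backends above now read the scope straight off the buffer variable's pointer type instead of consulting a side table. A minimal Python sketch of the state they consume (assuming a TVM build from this series; `tir.Var` and `PointerType` are existing public APIs):

    import tvm
    from tvm import tir
    from tvm.ir import PointerType, PrimType

    # A buffer variable carrying its storage scope on the pointer type
    # annotation; this is what GetPtrStorageScope reads on the C++ side.
    buf_var = tir.Var("A", PointerType(PrimType("float32"), "shared"))
    assert buf_var.type_annotation.storage_scope == "shared"

An empty scope string is treated as global by StorageScope::Create, so unannotated pointers keep the old default behavior.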
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 66b401c731e3..d7dcbec7ebe3 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -705,7 +705,7 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { this->PrintIndent(); int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; - std::string scope = GetStorageScope(op->buffer_var); + std::string scope = GetPtrStorageScope(op->buffer_var); if (scope.find("wmma.") == 0) { if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) || diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 7c9dfcaf95e0..cc20b985e3c6 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -647,7 +647,7 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) { ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; spirv::Value buf; - auto storage_scope = runtime::StorageScope::Create(GetStorageScope(op->buffer_var)); + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); spirv::SType etype = builder_->GetSType(op->dtype); if (storage_scope.rank == runtime::StorageRank::kLocal) { buf = diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 670e503d6114..abf75924b582 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -52,7 +52,7 @@ Buffer decl_buffer(Array shape, DataType dtype, String name, String st Array(), PrimExpr(), name, "", 0, 0, kDefault, span); } -String GetStorageScope(Var buffer_var) { +String GetPtrStorageScope(Var buffer_var) { auto type = buffer_var->type_annotation; const auto* ptr_type = type.as(); ICHECK(ptr_type) << "The provided variable is not of pointer type"; diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 9488d1e99fcb..2749e148a78d 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -54,8 +54,8 @@ class RemapStorageScope final : public StmtExprMutator { Stmt VisitStmt_(const AllocateNode* op) final { auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); - auto new_scope = GetStorageScope(remapped); - if (new_scope != GetStorageScope(op->buffer_var)) { + auto new_scope = GetPtrStorageScope(remapped); + if (new_scope != GetPtrStorageScope(op->buffer_var)) { Stmt body = StmtExprMutator::VisitStmt(op->body); if (new_scope == "shared") { // use volatile access to shared buffer. 
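Note: TIR vars are immutable, so "changing" a scope means minting a fresh Var, which is why these passes keep an old-to-new remap table rather than editing in place. A rough Python analogue of the WithStorageScope helper (hypothetical name `with_storage_scope`; sketch only):

    import tvm
    from tvm import tir
    from tvm.ir import PointerType

    def with_storage_scope(buffer_var, scope):
        # Rebuild the Var with the same name and element type but a
        # re-scoped pointer annotation, mirroring the C++ helper.
        ptr_type = buffer_var.type_annotation
        assert isinstance(ptr_type, PointerType), "not a pointer-typed variable"
        return tir.Var(buffer_var.name, PointerType(ptr_type.element_type, scope))

Every Load, Store, and Allocate that mentions the old Var must then be rewritten to the new one, which is exactly what the mutators in this series do.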
diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 8f7382a21153..49bc317a1e9c 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -58,7 +58,7 @@ class RemapStorageScope final : public StmtExprMutator { auto it = new_var_remap_.find(buf); if (it != new_var_remap_.end()) { auto remapped = it->second; - auto new_scope = GetStorageScope(remapped); + auto new_scope = GetPtrStorageScope(remapped); return AttrStmt(remapped, attr::storage_scope, StringImm(new_scope), StmtMutator::VisitStmt(op->body)); } diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc index 952758c4e5a7..9dae0006facd 100644 --- a/src/tir/transforms/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -242,7 +242,7 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { StorageScope StorageAccessVisitor::GetScope(Var buffer_var) const { if (buffer_var->type_annotation.as()) { - return StorageScope::Create(GetStorageScope(buffer_var)); + return StorageScope::Create(GetPtrStorageScope(buffer_var)); } return StorageScope(); // global by default } diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index eca3bba83583..3eccf300639a 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -156,7 +156,7 @@ class StorageFlattener : public StmtExprMutator { } // deduce current storage scope. StorageScope skey; - std::string strkey = GetStorageScope(op->buffer->data); + std::string strkey = GetPtrStorageScope(op->buffer->data); if (strkey.length() == 0) { if (curr_thread_scope_.size() != 0) { skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank); diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index a18bc84604aa..613d02614b39 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -706,7 +706,7 @@ class StoragePlanRewriter : public StmtExprMutator { for (const VarNode* var : it->second.gen) { ICHECK(alloc_info.count(var)); const AllocateNode* alloc = alloc_info.at(var).alloc; - auto storage_scope = StorageScope::Create(GetStorageScope(GetRef(var))); + auto storage_scope = StorageScope::Create(GetPtrStorageScope(GetRef(var))); StorageEntry* dst_entry = nullptr; // inplace detection if (detect_inplace) { @@ -923,7 +923,7 @@ class VectorAllocRewriter : public StmtExprMutator { // create a new buffer var DataType new_dtype = tvec[0]; Var new_buffer_var(op->buffer_var->name_hint, - PointerType(PrimType(new_dtype), GetStorageScope(op->buffer_var))); + PointerType(PrimType(new_dtype), GetPtrStorageScope(op->buffer_var))); // update the remap req. var_remap_.Set(op->buffer_var, new_buffer_var); return Allocate(new_buffer_var, new_dtype, extents, op->condition, op->body); diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index 896224c0e956..ba033f7e97e5 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -286,7 +286,7 @@ class ThreadSyncInserter : public StmtExprMutator { // Get current storage scope. StorageScope GetScope(Var buffer_var) const { - return StorageScope::Create(GetStorageScope(buffer_var)); + return StorageScope::Create(GetPtrStorageScope(buffer_var)); } // private functions. 
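Note: the next commit turns on the invariant check in the AttrStmt constructor, so a storage_scope attribute must agree with the scope already recorded on the pointer type. A hypothetical sketch of what the check accepts and rejects, against a build of this branch:

    import tvm
    from tvm import tir
    from tvm.ir import PointerType, PrimType

    v = tir.Var("p", PointerType(PrimType("float32"), "shared"))
    body = tir.Evaluate(tvm.runtime.convert(0))
    # Consistent: the attribute value matches the pointer annotation.
    ok = tir.AttrStmt(v, "storage_scope", tir.StringImm("shared"), body)
    # Inconsistent: expected to trip the new ICHECK ("local" vs "shared").
    # bad = tir.AttrStmt(v, "storage_scope", tir.StringImm("local"), body)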
From afbe7f1d1543c3bc7d0dc7a52765995ed45f237f Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 7 Jul 2021 07:34:18 +0900 Subject: [PATCH 33/90] Enable storage scope invariant check in AttrStmt constructor --- src/tir/ir/stmt.cc | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 18a5c2691005..1946faddb6bb 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -62,12 +62,15 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // AttrStmt AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span) { // TODO(masahi): Enable this invariant check - // if (attr_key == attr::storage_scope) { - // const VarNode* buf = node.as(); - // ICHECK(buf); - // ICHECK(value.as()->value == GetStorageScope(GetRef(buf))) - // << value.as()->value << ", " << GetStorageScope(GetRef(buf)); - // } + if (attr_key == attr::storage_scope) { + const VarNode* buf = node.as(); + ICHECK(buf); + auto attr_scope = value.as()->value; + auto buffer_scope = GetPtrStorageScope(GetRef(buf)); + ICHECK(attr_scope == buffer_scope) + << "Storage scopes attached to AttrStmt and buffer var are different. " << attr_scope + << ", " << buffer_scope; + } auto n = make_object(); n->node = node; n->attr_key = std::move(attr_key); From ee3aa5d601eb07a926026eaa7e425cb26ea09dfc Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 7 Jul 2021 08:02:48 +0900 Subject: [PATCH 34/90] remove GetPtrStorageScope and WithStorageScope from public header --- include/tvm/tir/buffer.h | 8 +------- src/target/source/codegen_c.h | 1 + src/target/spirv/codegen_spirv.cc | 2 +- src/tir/ir/buffer.cc | 14 -------------- src/tir/ir/stmt.cc | 8 ++++---- src/tir/transforms/ir_utils.cc | 6 ++++++ src/tir/transforms/ir_utils.h | 6 ++++++ src/tir/transforms/lower_thread_allreduce.cc | 7 +++++++ src/tir/transforms/lower_warp_memory.cc | 8 ++++++++ src/tir/transforms/thread_storage_sync.cc | 1 - 10 files changed, 34 insertions(+), 27 deletions(-) diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index 73416d7fad2c..d26221d6a4ff 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -199,13 +199,7 @@ class Buffer : public ObjectRef { TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Float(32), String name = "buffer", String storage_scope = "", Span span = Span()); -/*! - * \brief Return the storage scope associated with a buffer variable. - * \param buffer_var The input buffer variable. - * \return A string representing the storage scope of this buffer variable. - */ -TVM_DLL String GetPtrStorageScope(Var buffer_var); -TVM_DLL Var WithStorageScope(Var buffer_var, String storage_scope); + /*! * \brief Base node for data producers. 
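Note: with the helpers gone from the public header, the pointer annotation stays the single source of truth, and user code reaches it through decl_buffer. A sketch, assuming the Python decl_buffer forwards its `scope` argument into the data Var's pointer type the same way the C++ overload above does:

    import tvm
    from tvm import tir

    buf = tir.decl_buffer((16, 16), "float32", name="B", scope="shared")
    # The scope is recoverable from the backing Var alone; no side-band
    # storage_scope attribute is required.
    print(buf.data.type_annotation.storage_scope)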
diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index ae451f39f89b..834c57ac10fd 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -39,6 +39,7 @@ #include #include +#include "../../tir/transforms/ir_utils.h" #include "codegen_source_base.h" namespace tvm { diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index cc20b985e3c6..c1fa921d4507 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -23,7 +23,6 @@ */ #include "codegen_spirv.h" -#include #include #include #include @@ -33,6 +32,7 @@ #include "../../runtime/pack_args.h" #include "../../runtime/vulkan/vulkan_common.h" #include "../../runtime/vulkan/vulkan_shader.h" +#include "../../tir/transforms/ir_utils.h" namespace tvm { namespace codegen { diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index abf75924b582..e2fcf89d8966 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -52,20 +52,6 @@ Buffer decl_buffer(Array shape, DataType dtype, String name, String st Array(), PrimExpr(), name, "", 0, 0, kDefault, span); } -String GetPtrStorageScope(Var buffer_var) { - auto type = buffer_var->type_annotation; - const auto* ptr_type = type.as(); - ICHECK(ptr_type) << "The provided variable is not of pointer type"; - return ptr_type->storage_scope; -} - -Var WithStorageScope(Var buffer_var, String storage_scope) { - auto* ptr_type = buffer_var->type_annotation.as(); - ICHECK(ptr_type) << "The provided variable is not of pointer type"; - return Var(buffer_var->name_hint, PointerType(ptr_type->element_type, storage_scope), - buffer_var->span); -} - // Split the given expression w.r.t the add operator inline std::vector ExprSplitAddition(const PrimExpr& expr) { using namespace tir; diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 1946faddb6bb..e39e9b608474 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -61,15 +61,15 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // AttrStmt AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span) { - // TODO(masahi): Enable this invariant check if (attr_key == attr::storage_scope) { const VarNode* buf = node.as(); ICHECK(buf); + const auto* ptr_type = buf->type_annotation.as(); + ICHECK(ptr_type) << "The provided variable is not of pointer type"; auto attr_scope = value.as()->value; - auto buffer_scope = GetPtrStorageScope(GetRef(buf)); - ICHECK(attr_scope == buffer_scope) + ICHECK(attr_scope == ptr_type->storage_scope) << "Storage scopes attached to AttrStmt and buffer var are different. 
" << attr_scope - << ", " << buffer_scope; + << ", " << ptr_type->storage_scope; } auto n = make_object(); n->node = node; diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index cbae3f95ec68..f7ece25d3fcd 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -201,5 +201,11 @@ class IRConvertSSA final : public StmtExprMutator { Stmt ConvertSSA(Stmt stmt) { return IRConvertSSA()(std::move(stmt)); } +String GetPtrStorageScope(Var buffer_var) { + const auto* ptr_type = buffer_var->type_annotation.as(); + ICHECK(ptr_type) << "The provided variable is not of pointer type"; + return ptr_type->storage_scope; +} + } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index 906ff8a38b6c..b5a154b707af 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -191,6 +191,12 @@ inline PrimExpr StackAlloca(std::string type, size_t num) { */ Stmt ConvertSSA(Stmt stmt); +/*! + * \brief Return the storage scope associated with a buffer variable. + * \param buffer_var The input buffer variable. + * \return A string representing the storage scope of this buffer variable. + */ +String GetPtrStorageScope(Var buffer_var); } // namespace tir } // namespace tvm #endif // TVM_TIR_TRANSFORMS_IR_UTILS_H_ diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 2749e148a78d..ebd3ba88d7f7 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -39,6 +39,13 @@ namespace tir { namespace { +Var WithStorageScope(Var buffer_var, String storage_scope) { + auto* ptr_type = buffer_var->type_annotation.as(); + ICHECK(ptr_type) << "The provided variable is not of pointer type"; + return Var(buffer_var->name_hint, PointerType(ptr_type->element_type, storage_scope), + buffer_var->span); +} + class RemapStorageScope final : public StmtExprMutator { public: explicit RemapStorageScope(const std::unordered_map& new_var_remap) diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 49bc317a1e9c..3962472e57e6 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -40,12 +40,20 @@ #include "../../arith/pattern_match.h" #include "../../runtime/thread_storage_scope.h" +#include "ir_utils.h" namespace tvm { namespace tir { namespace { +Var WithStorageScope(Var buffer_var, String storage_scope) { + auto* ptr_type = buffer_var->type_annotation.as(); + ICHECK(ptr_type) << "The provided variable is not of pointer type"; + return Var(buffer_var->name_hint, PointerType(ptr_type->element_type, storage_scope), + buffer_var->span); +} + class RemapStorageScope final : public StmtExprMutator { public: explicit RemapStorageScope(const std::unordered_map& new_var_remap) diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index ba033f7e97e5..35e4563b8f58 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -22,7 +22,6 @@ */ #include #include -#include #include #include #include From accfff4b19f3051e089b29b826ab9ea60e1da86a Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 7 Jul 2021 08:17:18 +0900 Subject: [PATCH 35/90] move RemapStorageScope to its own file --- include/tvm/tir/buffer.h | 2 - src/tir/transforms/lower_warp_memory.cc | 42 ++--------- .../transforms/remap_pointer_storage_scope.cc | 69 +++++++++++++++++++ 
.../transforms/remap_pointer_storage_scope.h | 44 ++++++++++++ 4 files changed, 117 insertions(+), 40 deletions(-) create mode 100644 src/tir/transforms/remap_pointer_storage_scope.cc create mode 100644 src/tir/transforms/remap_pointer_storage_scope.h diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index d26221d6a4ff..2507262c087f 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -199,8 +199,6 @@ class Buffer : public ObjectRef { TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Float(32), String name = "buffer", String storage_scope = "", Span span = Span()); - - /*! * \brief Base node for data producers. * diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 3962472e57e6..bc43d4da5311 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -41,45 +41,11 @@ #include "../../arith/pattern_match.h" #include "../../runtime/thread_storage_scope.h" #include "ir_utils.h" +#include "remap_pointer_storage_scope.h" namespace tvm { namespace tir { -namespace { - -Var WithStorageScope(Var buffer_var, String storage_scope) { - auto* ptr_type = buffer_var->type_annotation.as(); - ICHECK(ptr_type) << "The provided variable is not of pointer type"; - return Var(buffer_var->name_hint, PointerType(ptr_type->element_type, storage_scope), - buffer_var->span); -} - -class RemapStorageScope final : public StmtExprMutator { - public: - explicit RemapStorageScope(const std::unordered_map& new_var_remap) - : new_var_remap_(new_var_remap) {} - - Stmt VisitStmt_(const AttrStmtNode* op) { - using runtime::StorageScope; - if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - auto it = new_var_remap_.find(buf); - if (it != new_var_remap_.end()) { - auto remapped = it->second; - auto new_scope = GetPtrStorageScope(remapped); - return AttrStmt(remapped, attr::storage_scope, StringImm(new_scope), - StmtMutator::VisitStmt(op->body)); - } - } - return StmtMutator::VisitStmt_(op); - } - - private: - std::unordered_map new_var_remap_; -}; - -} // namespace - // Rewrite Rule // // There is no special warp memory in most GPUs. 
@@ -392,7 +358,7 @@ class WarpMemoryRewriter : private StmtMutator { return stmt; } - std::unordered_map new_var_remap_; + std::unordered_map new_storage_scopes_; private: Stmt VisitStmt_(const AllocateNode* op) { @@ -412,7 +378,7 @@ class WarpMemoryRewriter : private StmtMutator { StorageScope scope = StorageScope::Create(op->value.as()->value); if (scope.rank == runtime::StorageRank::kWarp) { warp_buffer_.insert(buf); - new_var_remap_[buf] = WithStorageScope(GetRef(buf), "local"); + new_storage_scopes_[buf] = "local"; } } return StmtMutator::VisitStmt_(op); @@ -435,7 +401,7 @@ Pass LowerWarpMemory() { int warp_size = target.value()->GetAttr("thread_warp_size", 1).value(); WarpMemoryRewriter warp_memory_rewriter(warp_size); auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body)); - n->body = RemapStorageScope(warp_memory_rewriter.new_var_remap_)(stmt); + n->body = RemapStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {}); diff --git a/src/tir/transforms/remap_pointer_storage_scope.cc b/src/tir/transforms/remap_pointer_storage_scope.cc new file mode 100644 index 000000000000..70250faeca1f --- /dev/null +++ b/src/tir/transforms/remap_pointer_storage_scope.cc @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * TODO + * \file remap_pointer_storage_scope.cc + */ +#include "remap_pointer_storage_scope.h" + +#include +#include +#include +#include + +#include + +#include "../../runtime/thread_storage_scope.h" +#include "ir_utils.h" + +namespace tvm { +namespace tir { + +Var WithStorageScope(const VarNode* buffer_var, String storage_scope) { + auto* ptr_type = buffer_var->type_annotation.as(); + ICHECK(ptr_type) << "The provided variable is not of pointer type"; + return Var(buffer_var->name_hint, PointerType(ptr_type->element_type, storage_scope), + buffer_var->span); +} + +RemapStorageScope::RemapStorageScope( + const std::unordered_map& new_storage_scopes) { + for (auto kv : new_storage_scopes) { + new_var_remap_[kv.first] = WithStorageScope(kv.first, kv.second); + } +} + +Stmt RemapStorageScope::VisitStmt_(const AttrStmtNode* op) { + using runtime::StorageScope; + if (op->attr_key == attr::storage_scope) { + const VarNode* buf = op->node.as(); + auto it = new_var_remap_.find(buf); + if (it != new_var_remap_.end()) { + auto remapped = it->second; + auto new_scope = GetPtrStorageScope(remapped); + return AttrStmt(remapped, attr::storage_scope, StringImm(new_scope), + StmtMutator::VisitStmt(op->body)); + } + } + return StmtMutator::VisitStmt_(op); +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/remap_pointer_storage_scope.h b/src/tir/transforms/remap_pointer_storage_scope.h new file mode 100644 index 000000000000..051f757ddc0c --- /dev/null +++ b/src/tir/transforms/remap_pointer_storage_scope.h @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * TODO + * \file remap_pointer_storage_scope.h + */ +#include +#include +#include + +#include + +namespace tvm { +namespace tir { + +class RemapStorageScope final : public StmtExprMutator { + public: + explicit RemapStorageScope(const std::unordered_map& new_storage_scopes); + + virtual Stmt VisitStmt_(const AttrStmtNode* op); + + private: + std::unordered_map new_var_remap_; +}; + +} // namespace tir +} // namespace tvm From a74aecedee4c24566c29675e46bb98ee29d93c6a Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 7 Jul 2021 08:27:04 +0900 Subject: [PATCH 36/90] add more method to RemapStorageScope --- .../transforms/remap_pointer_storage_scope.cc | 37 +++++++++++++++---- .../transforms/remap_pointer_storage_scope.h | 6 ++- 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/src/tir/transforms/remap_pointer_storage_scope.cc b/src/tir/transforms/remap_pointer_storage_scope.cc index 70250faeca1f..61b22f078931 100644 --- a/src/tir/transforms/remap_pointer_storage_scope.cc +++ b/src/tir/transforms/remap_pointer_storage_scope.cc @@ -50,20 +50,43 @@ RemapStorageScope::RemapStorageScope( } } +PrimExpr RemapStorageScope::VisitExpr_(const VarNode* op) { + auto it = new_var_remap_.find(op); + if (it == new_var_remap_.end()) { + return GetRef(op); + } + return it->second; +} + +PrimExpr RemapStorageScope::VisitExpr_(const LoadNode* op) { + auto remapped = StmtExprMutator::VisitExpr(op->buffer_var); + return Load(op->dtype, Downcast(remapped), StmtExprMutator::VisitExpr(op->index), + StmtExprMutator::VisitExpr(op->predicate)); +} + Stmt RemapStorageScope::VisitStmt_(const AttrStmtNode* op) { using runtime::StorageScope; if (op->attr_key == attr::storage_scope) { const VarNode* buf = op->node.as(); - auto it = new_var_remap_.find(buf); - if (it != new_var_remap_.end()) { - auto remapped = it->second; - auto new_scope = GetPtrStorageScope(remapped); - return AttrStmt(remapped, attr::storage_scope, StringImm(new_scope), - StmtMutator::VisitStmt(op->body)); - } + auto remapped = Downcast(StmtExprMutator::VisitExpr(GetRef(buf))); + auto new_scope = GetPtrStorageScope(remapped); + return AttrStmt(remapped, attr::storage_scope, StringImm(new_scope), + StmtMutator::VisitStmt(op->body)); } return StmtMutator::VisitStmt_(op); } +Stmt RemapStorageScope::VisitStmt_(const AllocateNode* op) { + auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); + return Allocate(remapped, op->dtype, op->extents, StmtExprMutator::VisitExpr(op->condition), + StmtExprMutator::VisitStmt(op->body)); +} + +Stmt RemapStorageScope::VisitStmt_(const StoreNode* op) { + auto remapped = StmtExprMutator::VisitExpr(op->buffer_var); + return Store(Downcast(remapped), StmtExprMutator::VisitExpr(op->value), + StmtExprMutator::VisitExpr(op->index), StmtExprMutator::VisitExpr(op->predicate)); +} + } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/remap_pointer_storage_scope.h b/src/tir/transforms/remap_pointer_storage_scope.h index 051f757ddc0c..9689effd11fe 100644 --- a/src/tir/transforms/remap_pointer_storage_scope.h +++ b/src/tir/transforms/remap_pointer_storage_scope.h @@ -34,7 +34,11 @@ class RemapStorageScope final : public StmtExprMutator { public: explicit RemapStorageScope(const std::unordered_map& new_storage_scopes); - virtual Stmt VisitStmt_(const AttrStmtNode* op); + virtual PrimExpr VisitExpr_(const VarNode*); + virtual PrimExpr VisitExpr_(const LoadNode*); + virtual Stmt VisitStmt_(const AttrStmtNode*); + virtual Stmt VisitStmt_(const AllocateNode*); + virtual Stmt VisitStmt_(const 
StoreNode*); private: std::unordered_map new_var_remap_; From 98c3c3cc9bd180a94b759237553b2c4cda1b701b Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 7 Jul 2021 08:34:32 +0900 Subject: [PATCH 37/90] update lower_thread_allreduce to use RemapStorageScope --- src/tir/transforms/lower_thread_allreduce.cc | 52 ++++--------------- .../transforms/remap_pointer_storage_scope.cc | 2 +- .../transforms/remap_pointer_storage_scope.h | 8 ++- 3 files changed, 17 insertions(+), 45 deletions(-) diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index ebd3ba88d7f7..d4c1c6aea594 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -33,31 +33,16 @@ #include "../../runtime/thread_storage_scope.h" #include "ir_utils.h" +#include "remap_pointer_storage_scope.h" namespace tvm { namespace tir { -namespace { - -Var WithStorageScope(Var buffer_var, String storage_scope) { - auto* ptr_type = buffer_var->type_annotation.as(); - ICHECK(ptr_type) << "The provided variable is not of pointer type"; - return Var(buffer_var->name_hint, PointerType(ptr_type->element_type, storage_scope), - buffer_var->span); -} - -class RemapStorageScope final : public StmtExprMutator { +class RemapStorageScopeAllReduce final : public RemapStorageScope { public: - explicit RemapStorageScope(const std::unordered_map& new_var_remap) - : new_var_remap_(new_var_remap) {} - - PrimExpr VisitExpr_(const VarNode* op) final { - auto it = new_var_remap_.find(op); - if (it == new_var_remap_.end()) { - return GetRef(op); - } - return it->second; - } + explicit RemapStorageScopeAllReduce( + const std::unordered_map& new_storage_scopes) + : RemapStorageScope(new_storage_scopes) {} Stmt VisitStmt_(const AllocateNode* op) final { auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); @@ -73,25 +58,8 @@ class RemapStorageScope final : public StmtExprMutator { } return StmtExprMutator::VisitStmt_(op); } - - Stmt VisitStmt_(const StoreNode* op) final { - auto remapped = StmtExprMutator::VisitExpr(op->buffer_var); - return Store(Downcast(remapped), StmtExprMutator::VisitExpr(op->value), - StmtExprMutator::VisitExpr(op->index), StmtExprMutator::VisitExpr(op->predicate)); - } - - PrimExpr VisitExpr_(const LoadNode* op) final { - auto remapped = StmtExprMutator::VisitExpr(op->buffer_var); - return Load(op->dtype, Downcast(remapped), StmtExprMutator::VisitExpr(op->index), - StmtExprMutator::VisitExpr(op->predicate)); - } - - private: - std::unordered_map new_var_remap_; }; -} // namespace - class ThreadAllreduceBuilder final : public StmtExprMutator { public: explicit ThreadAllreduceBuilder(const TargetNode* target) @@ -141,10 +109,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { const AllocateNode* repl = it->second.as(); if (warp_allocs_.count(repl)) { stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body); - new_var_remap_[repl->buffer_var.get()] = WithStorageScope(repl->buffer_var, "local"); + new_storage_scopes_[repl->buffer_var.get()] = "local"; } else { stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body); - new_var_remap_[repl->buffer_var.get()] = WithStorageScope(repl->buffer_var, "shared"); + new_storage_scopes_[repl->buffer_var.get()] = "shared"; } return stmt; } else { @@ -161,7 +129,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } } - std::unordered_map new_var_remap_; + std::unordered_map new_storage_scopes_; 
private: // Thread entry @@ -421,7 +389,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { const AllocateNode* repl = var.as(); if (repl) { body = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body); - new_var_remap_[repl->buffer_var.get()] = WithStorageScope(repl->buffer_var, "local"); + new_storage_scopes_[repl->buffer_var.get()] = "local"; } } @@ -647,7 +615,7 @@ Pass LowerThreadAllreduce() { const TargetNode* target_node = target.as(); ThreadAllreduceBuilder thread_all_reduce(target_node); auto reduce_body = thread_all_reduce(n->body); - n->body = RemapStorageScope(thread_all_reduce.new_var_remap_)(reduce_body); + n->body = RemapStorageScopeAllReduce(thread_all_reduce.new_storage_scopes_)(reduce_body); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {}); diff --git a/src/tir/transforms/remap_pointer_storage_scope.cc b/src/tir/transforms/remap_pointer_storage_scope.cc index 61b22f078931..8225cddeb094 100644 --- a/src/tir/transforms/remap_pointer_storage_scope.cc +++ b/src/tir/transforms/remap_pointer_storage_scope.cc @@ -28,7 +28,7 @@ #include #include -#include +#include #include "../../runtime/thread_storage_scope.h" #include "ir_utils.h" diff --git a/src/tir/transforms/remap_pointer_storage_scope.h b/src/tir/transforms/remap_pointer_storage_scope.h index 9689effd11fe..0756810d8c9e 100644 --- a/src/tir/transforms/remap_pointer_storage_scope.h +++ b/src/tir/transforms/remap_pointer_storage_scope.h @@ -21,16 +21,19 @@ * TODO * \file remap_pointer_storage_scope.h */ +#ifndef TVM_TIR_TRANSFORMS_REMAP_POINTER_STORAGE_SCOPE_H_ +#define TVM_TIR_TRANSFORMS_REMAP_POINTER_STORAGE_SCOPE_H_ + #include #include #include -#include +#include namespace tvm { namespace tir { -class RemapStorageScope final : public StmtExprMutator { +class RemapStorageScope : public StmtExprMutator { public: explicit RemapStorageScope(const std::unordered_map& new_storage_scopes); @@ -46,3 +49,4 @@ class RemapStorageScope final : public StmtExprMutator { } // namespace tir } // namespace tvm +#endif // TVM_TIR_TRANSFORMS_REMAP_POINTER_STORAGE_SCOPE_H_ From bafdb47071b0d20e123cd7cfb85a306fdfcdab50 Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 7 Jul 2021 09:26:48 +0900 Subject: [PATCH 38/90] RemapStorageScope -> UpdatePointerStorageScope --- python/tvm/tir/ir_builder.py | 2 +- src/tir/transforms/lower_thread_allreduce.cc | 11 ++++++----- src/tir/transforms/lower_warp_memory.cc | 4 ++-- ...cope.cc => update_pointer_storage_scope.cc} | 18 +++++++++--------- ..._scope.h => update_pointer_storage_scope.h} | 15 ++++++++------- 5 files changed, 26 insertions(+), 24 deletions(-) rename src/tir/transforms/{remap_pointer_storage_scope.cc => update_pointer_storage_scope.cc} (84%) rename src/tir/transforms/{remap_pointer_storage_scope.h => update_pointer_storage_scope.h} (74%) diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 1573d96e7d0d..03c1339c4d38 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -436,7 +436,7 @@ def pointer(self, content_type, name="ptr", scope=""): The name of the pointer. scope : str, optional - The scope of the buffer. + The scope of the pointer. 
Returns ------- diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index d4c1c6aea594..25a2f4e060dd 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -33,16 +33,16 @@ #include "../../runtime/thread_storage_scope.h" #include "ir_utils.h" -#include "remap_pointer_storage_scope.h" +#include "update_pointer_storage_scope.h" namespace tvm { namespace tir { -class RemapStorageScopeAllReduce final : public RemapStorageScope { +class UpdatePointerStorageScopeAllReduce final : public UpdatePointerStorageScope { public: - explicit RemapStorageScopeAllReduce( + explicit UpdatePointerStorageScopeAllReduce( const std::unordered_map& new_storage_scopes) - : RemapStorageScope(new_storage_scopes) {} + : UpdatePointerStorageScope(new_storage_scopes) {} Stmt VisitStmt_(const AllocateNode* op) final { auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); @@ -615,7 +615,8 @@ Pass LowerThreadAllreduce() { const TargetNode* target_node = target.as(); ThreadAllreduceBuilder thread_all_reduce(target_node); auto reduce_body = thread_all_reduce(n->body); - n->body = RemapStorageScopeAllReduce(thread_all_reduce.new_storage_scopes_)(reduce_body); + n->body = + UpdatePointerStorageScopeAllReduce(thread_all_reduce.new_storage_scopes_)(reduce_body); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {}); diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index bc43d4da5311..060b02c3d137 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -41,7 +41,7 @@ #include "../../arith/pattern_match.h" #include "../../runtime/thread_storage_scope.h" #include "ir_utils.h" -#include "remap_pointer_storage_scope.h" +#include "update_pointer_storage_scope.h" namespace tvm { namespace tir { @@ -401,7 +401,7 @@ Pass LowerWarpMemory() { int warp_size = target.value()->GetAttr("thread_warp_size", 1).value(); WarpMemoryRewriter warp_memory_rewriter(warp_size); auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body)); - n->body = RemapStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt); + n->body = UpdatePointerStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {}); diff --git a/src/tir/transforms/remap_pointer_storage_scope.cc b/src/tir/transforms/update_pointer_storage_scope.cc similarity index 84% rename from src/tir/transforms/remap_pointer_storage_scope.cc rename to src/tir/transforms/update_pointer_storage_scope.cc index 8225cddeb094..ae72e7f947cd 100644 --- a/src/tir/transforms/remap_pointer_storage_scope.cc +++ b/src/tir/transforms/update_pointer_storage_scope.cc @@ -18,10 +18,10 @@ */ /*! - * TODO - * \file remap_pointer_storage_scope.cc + * \file update_pointer_storage_scope.cc + * \brief A pass to update storage scopes for buffer variables. 
*/ -#include "remap_pointer_storage_scope.h" +#include "update_pointer_storage_scope.h" #include #include @@ -43,14 +43,14 @@ Var WithStorageScope(const VarNode* buffer_var, String storage_scope) { buffer_var->span); } -RemapStorageScope::RemapStorageScope( +UpdatePointerStorageScope::UpdatePointerStorageScope( const std::unordered_map& new_storage_scopes) { for (auto kv : new_storage_scopes) { new_var_remap_[kv.first] = WithStorageScope(kv.first, kv.second); } } -PrimExpr RemapStorageScope::VisitExpr_(const VarNode* op) { +PrimExpr UpdatePointerStorageScope::VisitExpr_(const VarNode* op) { auto it = new_var_remap_.find(op); if (it == new_var_remap_.end()) { return GetRef(op); @@ -58,13 +58,13 @@ PrimExpr RemapStorageScope::VisitExpr_(const VarNode* op) { return it->second; } -PrimExpr RemapStorageScope::VisitExpr_(const LoadNode* op) { +PrimExpr UpdatePointerStorageScope::VisitExpr_(const LoadNode* op) { auto remapped = StmtExprMutator::VisitExpr(op->buffer_var); return Load(op->dtype, Downcast(remapped), StmtExprMutator::VisitExpr(op->index), StmtExprMutator::VisitExpr(op->predicate)); } -Stmt RemapStorageScope::VisitStmt_(const AttrStmtNode* op) { +Stmt UpdatePointerStorageScope::VisitStmt_(const AttrStmtNode* op) { using runtime::StorageScope; if (op->attr_key == attr::storage_scope) { const VarNode* buf = op->node.as(); @@ -76,13 +76,13 @@ Stmt RemapStorageScope::VisitStmt_(const AttrStmtNode* op) { return StmtMutator::VisitStmt_(op); } -Stmt RemapStorageScope::VisitStmt_(const AllocateNode* op) { +Stmt UpdatePointerStorageScope::VisitStmt_(const AllocateNode* op) { auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); return Allocate(remapped, op->dtype, op->extents, StmtExprMutator::VisitExpr(op->condition), StmtExprMutator::VisitStmt(op->body)); } -Stmt RemapStorageScope::VisitStmt_(const StoreNode* op) { +Stmt UpdatePointerStorageScope::VisitStmt_(const StoreNode* op) { auto remapped = StmtExprMutator::VisitExpr(op->buffer_var); return Store(Downcast(remapped), StmtExprMutator::VisitExpr(op->value), StmtExprMutator::VisitExpr(op->index), StmtExprMutator::VisitExpr(op->predicate)); diff --git a/src/tir/transforms/remap_pointer_storage_scope.h b/src/tir/transforms/update_pointer_storage_scope.h similarity index 74% rename from src/tir/transforms/remap_pointer_storage_scope.h rename to src/tir/transforms/update_pointer_storage_scope.h index 0756810d8c9e..481536a45b27 100644 --- a/src/tir/transforms/remap_pointer_storage_scope.h +++ b/src/tir/transforms/update_pointer_storage_scope.h @@ -18,11 +18,11 @@ */ /*! - * TODO - * \file remap_pointer_storage_scope.h + * \file update_pointer_storage_scope.h + * \brief A pass to update storage scopes for buffer variables. 
*/ -#ifndef TVM_TIR_TRANSFORMS_REMAP_POINTER_STORAGE_SCOPE_H_ -#define TVM_TIR_TRANSFORMS_REMAP_POINTER_STORAGE_SCOPE_H_ +#ifndef TVM_TIR_TRANSFORMS_UPDATE_POINTER_STORAGE_SCOPE_H_ +#define TVM_TIR_TRANSFORMS_UPDATE_POINTER_STORAGE_SCOPE_H_ #include #include @@ -33,9 +33,10 @@ namespace tvm { namespace tir { -class RemapStorageScope : public StmtExprMutator { +class UpdatePointerStorageScope : public StmtExprMutator { public: - explicit RemapStorageScope(const std::unordered_map& new_storage_scopes); + explicit UpdatePointerStorageScope( + const std::unordered_map& new_storage_scopes); virtual PrimExpr VisitExpr_(const VarNode*); virtual PrimExpr VisitExpr_(const LoadNode*); @@ -49,4 +50,4 @@ class RemapStorageScope : public StmtExprMutator { } // namespace tir } // namespace tvm -#endif // TVM_TIR_TRANSFORMS_REMAP_POINTER_STORAGE_SCOPE_H_ +#endif // TVM_TIR_TRANSFORMS_UPDATE_POINTER_STORAGE_SCOPE_H_ From 16c81b573e22ba86f9f9d1878a28f5176184c95d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 7 Jul 2021 14:57:50 +0900 Subject: [PATCH 39/90] remove realize_scope from hybrid script --- python/tvm/te/hybrid/parser.py | 3 +-- src/contrib/hybrid/codegen_hybrid.cc | 9 ++------- src/contrib/hybrid/codegen_hybrid.h | 2 -- tests/python/unittest/test_te_hybrid_script.py | 4 +--- 4 files changed, 4 insertions(+), 14 deletions(-) diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py index 7bb85e3da83c..442aeb6f1027 100644 --- a/python/tvm/te/hybrid/parser.py +++ b/python/tvm/te/hybrid/parser.py @@ -207,8 +207,7 @@ def wrap_up_realize(self, node, body): _domain = [Range.from_min_extent(0, i) for i in _buf.shape] _dtype = _buf.dtype _true = tvm.runtime.convert(True) - body = tvm.tir.ProducerRealize(_buf, _domain, _true, body) - body = tvm.tir.AttrStmt(_buf.op, "realize_scope", tvm.runtime.convert(_scope), body) + body = tvm.tir.ProducerRealize(_buf, _domain, _true, body, tvm.runtime.convert(_scope)) for elem in to_pop: self.symbols.pop(elem) diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index 7522f20523c8..54edbaee35cd 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -315,10 +315,6 @@ void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) { indent_ += tab_; PrintStmt(op->body); indent_ -= tab_; - } else if (op->attr_key == tir::attr::realize_scope) { - auto v = Downcast(op->node); - alloc_storage_scope_[v] = op->value.as()->value; - PrintStmt(op->body); } else { // For now we ignore the unsupported AttrStmt PrintStmt(op->body); @@ -327,8 +323,7 @@ void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) { void CodeGenHybrid::VisitStmt_(const ProducerRealizeNode* op) { auto tensor = Downcast(op->producer); - ICHECK(alloc_storage_scope_.count(tensor->op)); - if (!alloc_storage_scope_[tensor->op].empty()) { + if (!op->storage_scope.empty()) { PrintIndent(); stream << GetTensorID(tensor) << " = allocate(("; for (size_t i = 0; i < op->bounds.size(); ++i) { @@ -339,7 +334,7 @@ void CodeGenHybrid::VisitStmt_(const ProducerRealizeNode* op) { stream << "), '"; PrintType(tensor->dtype, stream); stream << "', '"; - stream << alloc_storage_scope_[tensor->op] << "')\n"; + stream << op->storage_scope << "')\n"; } PrintStmt(op->body); } diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index b01ca2763e28..47c13f73022f 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -168,8 +168,6 @@ class CodeGenHybrid : public 
ExprFunctor, * \param tensor The tensor to allocate a name. */ std::string GetTensorID(const Tensor& tensor); - /*! \brief the storage scope of allocation */ - std::map alloc_storage_scope_; }; } // namespace contrib diff --git a/tests/python/unittest/test_te_hybrid_script.py b/tests/python/unittest/test_te_hybrid_script.py index 30b96546f991..e9626e7f31b4 100644 --- a/tests/python/unittest/test_te_hybrid_script.py +++ b/tests/python/unittest/test_te_hybrid_script.py @@ -189,9 +189,7 @@ def fanout(n, a): assert ir.min.value == 0 assert tvm.ir.structural_equal(ir.extent, n - 3) # Check loopbody - ibody = ir.body - assert isinstance(ibody, tvm.tir.AttrStmt) - abody = ibody.body + abody = ir.body assert isinstance(abody, tvm.tir.ProducerRealize) assert abody.bounds[0].min.value == 0 assert abody.bounds[0].extent.value == 1 From a14c5fa140c5ebab321d7494a20dadf0de033ae5 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 7 Jul 2021 15:23:53 +0900 Subject: [PATCH 40/90] removed realize_scope in schedule_ops --- docs/dev/inferbound.rst | 2 -- src/te/schedule/schedule_ops.cc | 5 +---- tests/python/unittest/test_te_schedule_tensorize.py | 4 ++-- tests/python/unittest/test_te_tensor.py | 2 +- tests/python/unittest/test_tir_transform_loop_partition.py | 4 ++-- 5 files changed, 6 insertions(+), 11 deletions(-) diff --git a/docs/dev/inferbound.rst b/docs/dev/inferbound.rst index 010d0d42d37e..28e034dc44cb 100644 --- a/docs/dev/inferbound.rst +++ b/docs/dev/inferbound.rst @@ -447,13 +447,11 @@ Here is the IR after ScheduleOps (note that loops with extent 1 have been preser :: - // attr [compute(D, 0x2c070b0)] realize_scope = "" realize D([0, 4], [0, 5], [0, 16]) { produce D { for (di, 0, 4) { for (dj, 0, 5) { for (dk, 0, 16) { - // attr [compute(C, 0x2c29990)] realize_scope = "" realize C([dj, 1], [dk, 1]) { produce C { for (i, 0, 1) { diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index 21edd2f94b20..9faff741372b 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -51,11 +51,8 @@ Stmt MakePipeline(const Stage& s, const std::unordered_map& dom_ if (consumer.defined() && !is_no_op(consumer)) { pipeline = SeqStmt({producer, consumer}); } - pipeline = s->op->BuildRealize(s, dom_map, pipeline, s->scope); - // use attribute to mark scope of the operation. - pipeline = AttrStmt(s->op, tir::attr::realize_scope, StringImm(s->scope), pipeline); - return pipeline; + return s->op->BuildRealize(s, dom_map, pipeline, s->scope); } // inject the operator's realization on the stmt. 
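Note: with BuildRealize taking the scope directly, MakePipeline above no longer needs the realize_scope AttrStmt wrapper; the realize node records its scope itself. A small usage sketch of the new ProducerRealize signature from this series:

    import tvm
    from tvm import te, tir

    A = te.placeholder((4,), name="A")
    B = te.compute((4,), lambda i: A[i] + 1.0, name="B")
    bounds = [tvm.ir.Range.from_min_extent(0, 4)]
    realize = tir.ProducerRealize(
        B, bounds, tvm.runtime.convert(True),
        tir.Evaluate(tvm.runtime.convert(0)), "shared")
    assert realize.storage_scope == "shared"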
diff --git a/tests/python/unittest/test_te_schedule_tensorize.py b/tests/python/unittest/test_te_schedule_tensorize.py index e2c2f7f7e0e5..ae5e7051bfba 100644 --- a/tests/python/unittest/test_te_schedule_tensorize.py +++ b/tests/python/unittest/test_te_schedule_tensorize.py @@ -379,8 +379,8 @@ def intrin_func(ins, outs): stmt = tvm.te.schedule.ScheduleOps(s, dom_map) # The loop that we tried to tensorize still exists in the code # That means tensorize didn't work as expected - assert isinstance(stmt.body.body, tvm.tir.For) - assert stmt.body.body.loop_var.name == C.op.axis[0].var.name + assert isinstance(stmt.body, tvm.tir.For) + assert stmt.body.loop_var.name == C.op.axis[0].var.name if __name__ == "__main__": diff --git a/tests/python/unittest/test_te_tensor.py b/tests/python/unittest/test_te_tensor.py index ed4a21397885..2931925965b7 100644 --- a/tests/python/unittest/test_te_tensor.py +++ b/tests/python/unittest/test_te_tensor.py @@ -309,7 +309,7 @@ def get_B1_realize(x): ret = [] tvm.tir.stmt_functor.post_order_visit(stmt, get_B1_realize) - assert stmt.node == C.op and len(ret) == 1 + assert stmt.producer == C and len(ret) == 1 def test_tensor_inputs(): diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py index 9e8848083908..6194024748e0 100644 --- a/tests/python/unittest/test_tir_transform_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -40,7 +40,7 @@ def test_basic(): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt)) mod = tvm.tir.transform.LoopPartition()(mod) - stmt = tvm.tir.transform.Simplify()(mod)["main"].body + stmt = tvm.tir.transform.Simplify()(mod)["main"] assert not any(collect_visit(stmt.body.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))) assert any(collect_visit(stmt.body.body[1], lambda x: isinstance(x, tvm.tir.IfThenElse))) @@ -156,7 +156,7 @@ def test_thread_axis(): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) mod = tvm.tir.transform.LoopPartition()(mod) - stmt = tvm.tir.transform.Simplify()(mod)["main"].body + stmt = tvm.tir.transform.Simplify()(mod)["main"] assert not any(collect_visit(stmt.body.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))) From c2ea828b113ffaab04ddcd98db230d64113e2ffb Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 7 Jul 2021 15:43:51 +0900 Subject: [PATCH 41/90] remove realize_scope from schedule_postproc_to_primfunc --- src/te/schedule/schedule_postproc_to_primfunc.cc | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc index 8e6cc131b76e..2063fc7cad6a 100644 --- a/src/te/schedule/schedule_postproc_to_primfunc.cc +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -67,10 +67,7 @@ class TensorToBufferMapper : public StmtExprMutator { Stmt VisitStmt_(const AttrStmtNode* op) final { auto ret = StmtExprMutator::VisitStmt_(op); op = ret.as(); - // TODO(tvm-team): remove realize_scope, turn the info into - // Buffer's scope field in this pass. 
- if (op->attr_key == tir::attr::realize_scope || - op->attr_key == tir::attr::double_buffer_scope) { + if (op->attr_key == tir::attr::double_buffer_scope) { Stmt body = op->body; Operation operation = Downcast(op->node); for (int i = operation->num_outputs(); i != 0; --i) { From d401022bb748125e37af8033d2da3f8650b7a73e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 7 Jul 2021 15:52:57 +0900 Subject: [PATCH 42/90] remove remaining realize_scope usage from schedule_ops.cc --- src/te/schedule/schedule_ops.cc | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index 9faff741372b..825092d20ac0 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -172,8 +172,7 @@ class SchedulePostProc : public StmtExprMutator { thread_extent_scope_.erase(op->node.get()); return ret; } - } else if (op->attr_key == tir::attr::realize_scope || - op->attr_key == tir::attr::double_buffer_scope) { + } else if (op->attr_key == tir::attr::double_buffer_scope) { auto it = replace_op_.find(op->node.get()); if (it != replace_op_.end()) { if (it->second.defined()) { From de8362367c0ed8bbff72730ea883cbef4113644e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 7 Jul 2021 16:06:21 +0900 Subject: [PATCH 43/90] remove realize_scope usage from storage_flatten.cc --- src/tir/transforms/storage_flatten.cc | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 3eccf300639a..0db86130a8da 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -78,10 +78,7 @@ class StorageFlattener : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::realize_scope) { - return this->VisitStmt(op->body); - } else if (op->attr_key == attr::double_buffer_scope && - op->node->IsInstance()) { + if (op->attr_key == attr::double_buffer_scope && op->node->IsInstance()) { auto buffer = Downcast(op->node); Stmt body = this->VisitStmt(op->body); auto it = buf_map_.find(buffer); From 426b573476e34e317e757b7570863659fff0cc8f Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 7 Jul 2021 17:56:53 +0900 Subject: [PATCH 44/90] fixed test_tir_transform_lower_warp_memory.py following realize_scope removal --- tests/python/unittest/test_tir_transform_lower_warp_memory.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py index ef474c15cfbb..f3baff120cf6 100644 --- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py +++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py @@ -72,8 +72,8 @@ def test_lower_warp_memory_correct_indices(): bounds = tvm.te.schedule.InferBound(s) ir = tvm.te.schedule.ScheduleOps(s, bounds) - inner_func = ir.body.body.body.body - store_A_warp = inner_func.body.seq[0].body.body + inner_func = ir.body.body.body + store_A_warp = inner_func.seq[0].body.body indices = list(store_A_warp.indices) # A.warp is actually many buffers, one for each warp, although they are all called A.warp From 1964f7473cdfab87bda233723959b2b9a5440357 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 1 Jul 2021 06:24:41 +0900 Subject: [PATCH 45/90] Add storage scope to ProducerRealize, always create a buffer with scope --- include/tvm/te/operation.h | 15 ++++++------ 
include/tvm/tir/buffer.h | 10 +++++++- include/tvm/tir/stmt.h | 9 ++++++-- python/tvm/tir/ir_builder.py | 2 +- python/tvm/tir/stmt.py | 7 ++++-- src/runtime/thread_storage_scope.h | 4 +++- src/te/operation/compute_op.cc | 4 ++-- src/te/operation/extern_op.cc | 4 ++-- src/te/operation/hybrid_op.cc | 4 ++-- src/te/operation/placeholder_op.cc | 2 +- src/te/operation/scan_op.cc | 4 ++-- src/tir/ir/buffer.cc | 12 ++++++++-- src/tir/ir/stmt.cc | 7 +++--- src/tir/transforms/thread_storage_sync.cc | 28 ++++++++--------------- 14 files changed, 66 insertions(+), 46 deletions(-) diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index 27e48999a7d1..13f39317dbe4 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -125,11 +125,12 @@ class TVM_DLL OperationNode : public Object { * \param stage the op's stage. * \param realize_map The realization domain map of the operators. * \param body The body that is going to get + * \param storage_scope The storage scope associated with this realization * \return A realization statement that wraps body. */ virtual Stmt BuildRealize(const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const = 0; + const std::unordered_map& realize_map, const Stmt& body, + String storage_scope = "") const = 0; /*! * \brief Build the statement that provide the output tensors. * \param stage The schedule stage of the op. @@ -168,7 +169,7 @@ class PlaceholderOpNode : public OperationNode { void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, std::unordered_map* out_dom_map) const final; Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const final; + const Stmt& body, String storage_scope = "") const final; Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) const final; @@ -212,7 +213,7 @@ class TVM_DLL BaseComputeOpNode : public OperationNode { void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, std::unordered_map* out_dom_map) const final; Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const final; + const Stmt& body, String storage_scope = "") const final; virtual size_t num_schedulable_dims() const = 0; static constexpr const char* _type_key = "BaseComputeOp"; @@ -370,7 +371,7 @@ class ScanOpNode : public OperationNode { void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, std::unordered_map* out_dom_map) const final; Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const final; + const Stmt& body, String storage_scope = "") const final; Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) const final; @@ -433,7 +434,7 @@ class ExternOpNode : public OperationNode { void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, std::unordered_map* out_dom_map) const final; Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const final; + const Stmt& body, String storage_scope = "") const final; Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) const final; @@ -498,7 +499,7 @@ class HybridOpNode : public OperationNode { void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, std::unordered_map* out_dom_map) const final; Stmt BuildRealize(const Stage& stage, const 
std::unordered_map& realize_map, - const Stmt& body) const final; + const Stmt& body, String storage_scope = "") const final; Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) const final; diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index 017f4f7052b1..ed5718d8f358 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -191,12 +191,20 @@ class Buffer : public ObjectRef { * \param shape The shape of the buffer, * \param dtype The content data type. * \param name The name of the buffer + * \param storage_scope The storage scope associated with this buffer * \param span The location of this object in the source code. * \return The created buffer. * \sa Buffer for complete constructor. */ TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Float(32), - String name = "buffer", Span span = Span()); + String name = "buffer", String storage_scope = "", Span span = Span()); + +/*! + * \brief Return the storage scope associated with a buffer variable. + * \param buffer_var The input buffer variable. + * \return A string representing the storage scope of this buffer variable. + */ +TVM_DLL String GetStorageScope(Var buffer_var); /*! * \brief Base node for data producers. diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index cc10c218c8ff..9997a4d95694 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -464,18 +464,22 @@ class ProducerRealizeNode : public StmtNode { PrimExpr condition; /*! \brief The body of realization. */ Stmt body; + /*! \brief The storage scope associated with this realization. */ + String storage_scope; void VisitAttrs(AttrVisitor* v) { v->Visit("producer", &producer); v->Visit("bounds", &bounds); v->Visit("condition", &condition); v->Visit("body", &body); + v->Visit("storage_scope", &storage_scope); v->Visit("span", &span); } bool SEqualReduce(const ProducerRealizeNode* other, SEqualReducer equal) const { return equal(producer, other->producer) && equal(bounds, other->bounds) && - equal(condition, other->condition) && equal(body, other->body); + equal(condition, other->condition) && equal(body, other->body) && + equal(storage_scope, other->storage_scope); } void SHashReduce(SHashReducer hash_reduce) const { @@ -483,6 +487,7 @@ class ProducerRealizeNode : public StmtNode { hash_reduce(bounds); hash_reduce(condition); hash_reduce(body); + hash_reduce(storage_scope); } static constexpr const char* _type_key = "tir.ProducerRealize"; @@ -496,7 +501,7 @@ class ProducerRealizeNode : public StmtNode { class ProducerRealize : public Stmt { public: TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body, - Span span = Span()); + String storage_scope = "", Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode); }; diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 4934bf04727f..5aae068f4d58 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -416,7 +416,7 @@ def allocate(self, dtype, shape, name="buf", scope=None): buffer : BufferVar The buffer var representing the buffer. 
""" - buffer_var = _expr.Var(name, PointerType(PrimType(dtype))) + buffer_var = _expr.Var(name, PointerType(PrimType(dtype), scope)) if not isinstance(shape, (list, tuple, _container.Array)): shape = [shape] if scope: diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index dd7665a56692..94074b906777 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -364,13 +364,16 @@ class ProducerRealize(Stmt): body : Stmt The realize body + storage_scope : str + The storage scope associated with this realization + span : Optional[Span] The location of this itervar in the source code. """ - def __init__(self, producer, bounds, condition, body, span=None): + def __init__(self, producer, bounds, condition, body, storage_scope="", span=None): self.__init_handle_by_constructor__( - _ffi_api.ProducerRealize, producer, bounds, condition, body, span + _ffi_api.ProducerRealize, producer, bounds, condition, body, storage_scope, span ) diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index c0393600b60c..d93a1f130bae 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -118,7 +118,9 @@ struct StorageScope { */ static StorageScope Create(const std::string& s) { StorageScope r; - if (s.compare(0, 6, "global") == 0) { + if (s == "") { + r.rank = StorageRank::kGlobal; + } else if (s.compare(0, 6, "global") == 0) { r.rank = StorageRank::kGlobal; r.tag = s.substr(6, std::string::npos); } else if (s.compare(0, 6, "shared") == 0) { diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 9a4eadb35619..26c08955f5ad 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -260,7 +260,7 @@ void BaseComputeOpNode::GatherBound(const Operation& self, Stmt BaseComputeOpNode::BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const { + const Stmt& body, String storage_scope) const { ICHECK_EQ(stage->op.get(), this); Region bounds; for (IterVar iv : this->axis) { @@ -269,7 +269,7 @@ Stmt BaseComputeOpNode::BuildRealize(const Stage& stage, Stmt realize = body; for (int i = this->num_outputs(); i > 0; --i) { Tensor t = stage->op.output(i - 1); - realize = tir::ProducerRealize(t, bounds, const_true(), realize); + realize = tir::ProducerRealize(t, bounds, const_true(), realize, storage_scope); // alignment requirement, only useful for compute for (size_t i = 0; i < num_schedulable_dims(); ++i) { auto it = stage->iter_var_attrs.find(this->axis[i]); diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index 1c9a3cb336ae..b602efcfc28b 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -124,7 +124,7 @@ void ExternOpNode::GatherBound(const Operation& self, Stmt ExternOpNode::BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const { + const Stmt& body, String storage_scope) const { ICHECK_EQ(stage->op.get(), this); Stmt realize_body = body; for (int k = 0; k < num_outputs(); ++k) { @@ -133,7 +133,7 @@ Stmt ExternOpNode::BuildRealize(const Stage& stage, for (size_t i = 0; i < t->shape.size(); ++i) { bounds.push_back(Range::FromMinExtent(make_const(t->shape[i].dtype(), 0), t->shape[i])); } - realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body); + realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body, storage_scope); } return realize_body; } diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc 
index 65b8660ca1fb..5d2412abb3d2 100644 --- a/src/te/operation/hybrid_op.cc +++ b/src/te/operation/hybrid_op.cc @@ -144,7 +144,7 @@ void HybridOpNode::GatherBound(const Operation& self, Stmt HybridOpNode::BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const { + const Stmt& body, String storage_scope) const { // TODO(@were): Add attribute inject here and remove it from hybrid parser. ICHECK_EQ(stage->op.get(), this); Stmt realize_body = body; @@ -154,7 +154,7 @@ Stmt HybridOpNode::BuildRealize(const Stage& stage, for (size_t i = 0; i < t->shape.size(); ++i) { bounds.push_back(Range::FromMinExtent(make_const(t->shape[i].dtype(), 0), t->shape[i])); } - realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body); + realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body, storage_scope); } return realize_body; } diff --git a/src/te/operation/placeholder_op.cc b/src/te/operation/placeholder_op.cc index c51e53e16cd1..4f5df7ad3024 100644 --- a/src/te/operation/placeholder_op.cc +++ b/src/te/operation/placeholder_op.cc @@ -85,7 +85,7 @@ void PlaceholderOpNode::GatherBound(const Operation& self, Stmt PlaceholderOpNode::BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const { + const Stmt& body, String storage_scope) const { return body; } diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index a555e86097b7..39689bd9654a 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -234,7 +234,7 @@ void ScanOpNode::GatherBound(const Operation& self, } Stmt ScanOpNode::BuildRealize(const Stage& stage, const std::unordered_map& dom_map, - const Stmt& body) const { + const Stmt& body, String storage_scope) const { arith::Analyzer analyzer; ICHECK_EQ(stage->op.get(), this); Range sdom = dom_map.at(this->scan_axis); @@ -250,7 +250,7 @@ Stmt ScanOpNode::BuildRealize(const Stage& stage, const std::unordered_mapspatial_axis_[sp_idx]; bounds.push_back(dom_map.at(sp_ax)); } - ret = tir::ProducerRealize(t, bounds, const_true(), ret); + ret = tir::ProducerRealize(t, bounds, const_true(), ret, storage_scope); } return ret; } diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 1667eb7d1fbd..851d440a6378 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -45,12 +45,20 @@ Array SimplifyArray(arith::Analyzer* ana, Array array) { return array; } -Buffer decl_buffer(Array shape, DataType dtype, String name, Span span) { +Buffer decl_buffer(Array shape, DataType dtype, String name, String storage_scope, + Span span) { DataType storage_dtype = (dtype == DataType::Bool() ? 
DataType::Int(8) : dtype); - return Buffer(Var(name, PointerType(PrimType(storage_dtype)), span), dtype, shape, + return Buffer(Var(name, PointerType(PrimType(storage_dtype), storage_scope), span), dtype, shape, Array<PrimExpr>(), PrimExpr(), name, "", 0, 0, kDefault, span); } +String GetStorageScope(Var buffer_var) { + auto type = buffer_var->type_annotation; + const auto* ptr_type = type.as<PointerTypeNode>(); + ICHECK(ptr_type) << "The provided variable is not of pointer type"; + return ptr_type->storage_scope; +} + // Split the given expression w.r.t the add operator inline std::vector<const PrimExpr*> ExprSplitAddition(const PrimExpr& expr) { using namespace tir; diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index b2016eb74c91..6fdeb30ec100 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -377,7 +377,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // ProducerRealize ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, - Stmt body, Span span) { + Stmt body, String storage_scope, Span span) { for (size_t i = 0; i < bounds.size(); ++i) { ICHECK(bounds[i]->min.defined()); ICHECK(bounds[i]->extent.defined()); @@ -394,13 +394,14 @@ ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr node->condition = std::move(condition); node->body = std::move(body); node->span = std::move(span); + node->storage_scope = std::move(storage_scope); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.ProducerRealize") .set_body_typed([](DataProducer producer, Region bounds, PrimExpr condition, Stmt body, - Span span) { - return ProducerRealize(producer, bounds, condition, body, span); + String storage_scope, Span span) { + return ProducerRealize(producer, bounds, condition, body, storage_scope, span); }); TVM_REGISTER_NODE_TYPE(ProducerRealizeNode); diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index 8f757171afbd..896224c0e956 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -22,6 +22,7 @@ */ #include #include +#include #include #include #include @@ -223,14 +224,14 @@ class ThreadSyncInserter : public StmtExprMutator { } PrimExpr VisitExpr_(const LoadNode* op) final { if (sync_scope_.rank == StorageRank::kGlobal && - GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) { + GetScope(op->buffer_var).rank == StorageRank::kGlobal) { ++rw_stats_[op->buffer_var].read_count; } return StmtExprMutator::VisitExpr_(op); } Stmt VisitStmt_(const StoreNode* op) final { if (sync_scope_.rank == StorageRank::kGlobal && - GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) { + GetScope(op->buffer_var).rank == StorageRank::kGlobal) { ++rw_stats_[op->buffer_var].write_count; } return StmtExprMutator::VisitStmt_(op); @@ -250,10 +251,6 @@ class ThreadSyncInserter : public StmtExprMutator { is_lead_ = PrimExpr(); } return ret; - } else if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as<VarNode>(); - storage_scope_[buf] = StorageScope::Create(op->value.as<StringImmNode>()->value); - return StmtExprMutator::VisitStmt_(op); } else { return StmtExprMutator::VisitStmt_(op); } @@ -264,16 +261,15 @@ PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as<CallNode>(); ICHECK_EQ(op->args.size(), 5U); - const VarNode* buffer_var = op->args[1].as<VarNode>(); - Var var(GetRef<Var>(buffer_var)); + Var buffer_var(GetRef<Var>(op->args[1].as<VarNode>())); const IntImmNode* flag = op->args[4].as<IntImmNode>(); if ((flag->value & 1) && sync_scope_.rank == StorageRank::kGlobal
&& GetScope(buffer_var).rank == StorageRank::kGlobal) { - ++rw_stats_[var].read_count; + ++rw_stats_[buffer_var].read_count; } if (flag->value & 2 && sync_scope_.rank == StorageRank::kGlobal && GetScope(buffer_var).rank == StorageRank::kGlobal) { - ++rw_stats_[var].write_count; + ++rw_stats_[buffer_var].write_count; } return expr; } else { @@ -287,14 +283,12 @@ class ThreadSyncInserter : public StmtExprMutator { int read_count{0}; int write_count{0}; }; + // Get current storage scope. - StorageScope GetScope(const VarNode* buf) const { - auto it = storage_scope_.find(buf); - StorageScope s; - s.rank = StorageRank::kGlobal; - if (it == storage_scope_.end()) return s; - return it->second; + StorageScope GetScope(Var buffer_var) const { + return StorageScope::Create(GetStorageScope(buffer_var)); } + // private functions. Stmt InitGlobalBarrier(const AttrStmtNode* op) { ICHECK(op != nullptr); @@ -337,8 +331,6 @@ class ThreadSyncInserter : public StmtExprMutator { // data structure. StorageScope sync_scope_; const std::unordered_set& syncs_; - // The storage scope of each buffer - std::unordered_map storage_scope_; // The read write statistics of storage std::unordered_map rw_stats_; // The statistics for global barrier From 63e0c85e44aeb0573bb1c2e361d9545c8c71567b Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 1 Jul 2021 06:26:38 +0900 Subject: [PATCH 46/90] update schedule_ops.cc --- src/te/schedule/schedule_ops.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index 355e3c39494b..f130e1fb93e4 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -51,7 +51,7 @@ Stmt MakePipeline(const Stage& s, const std::unordered_map& dom_ if (consumer.defined() && !is_no_op(consumer)) { pipeline = SeqStmt({producer, consumer}); } - pipeline = s->op->BuildRealize(s, dom_map, pipeline); + pipeline = s->op->BuildRealize(s, dom_map, pipeline, s->scope); // use attribute to mark scope of the operation. 
pipeline = AttrStmt(s->op, tir::attr::realize_scope, StringImm(s->scope), pipeline); @@ -175,8 +175,7 @@ class SchedulePostProc : public StmtExprMutator { thread_extent_scope_.erase(op->node.get()); return ret; } - } else if (op->attr_key == tir::attr::realize_scope || - op->attr_key == tir::attr::double_buffer_scope) { + } else if (op->attr_key == tir::attr::double_buffer_scope) { auto it = replace_op_.find(op->node.get()); if (it != replace_op_.end()) { if (it->second.defined()) { @@ -218,7 +217,8 @@ class SchedulePostProc : public StmtExprMutator { auto it = replace_realize_.find(key); if (it != replace_realize_.end()) { if (it->second.defined()) { - Stmt ret = ProducerRealize(it->second, op->bounds, op->condition, op->body); + Stmt ret = + ProducerRealize(it->second, op->bounds, op->condition, op->body, op->storage_scope); return this->VisitStmt(ret); } else { return this->VisitStmt(op->body); From 24225e5cb5a5f4e38969549d357cc9d7cd9e9bf2 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 1 Jul 2021 06:31:11 +0900 Subject: [PATCH 47/90] update schedule_postproc_to_primfunc.cc --- src/te/schedule/schedule_postproc_to_primfunc.cc | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc index 5c59961fe011..8e6cc131b76e 100644 --- a/src/te/schedule/schedule_postproc_to_primfunc.cc +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -49,12 +49,12 @@ namespace tvm { namespace te { // create a buffer for tensor. -Buffer CreateBufferFor(const Tensor& tensor) { +Buffer CreateBufferFor(const Tensor& tensor, String storage_scope = "") { std::string name = tensor->op->name; if (tensor->op->num_outputs() != 1) { name += ".v" + std::to_string(tensor->value_index); } - Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, name); + Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, name, storage_scope); return buffer; } @@ -95,7 +95,7 @@ class TensorToBufferMapper : public StmtExprMutator { Stmt VisitStmt_(const ProducerRealizeNode* op) final { Tensor tensor = Downcast(op->producer); - Buffer buffer = GetOrAllocBuffer(tensor); + Buffer buffer = GetOrAllocBuffer(tensor, op->storage_scope); auto ret = StmtExprMutator::VisitStmt_(op); op = ret.as(); @@ -122,14 +122,16 @@ class TensorToBufferMapper : public StmtExprMutator { } private: - Buffer GetOrAllocBuffer(const Tensor& tensor) { return GetBuffer(tensor, true); } + Buffer GetOrAllocBuffer(const Tensor& tensor, String storage_scope = "") { + return GetBuffer(tensor, storage_scope, true); + } - Buffer GetBuffer(const Tensor& tensor, bool allow_alloc = false) { + Buffer GetBuffer(const Tensor& tensor, String storage_scope = "", bool allow_alloc = false) { auto it = buffer_map_.find(tensor); if (it != buffer_map_.end()) return it->second; ICHECK(allow_alloc) << "Cannot find the Realization point of tensor " << tensor; - auto buffer = CreateBufferFor(tensor); + auto buffer = CreateBufferFor(tensor, storage_scope); buffer_map_[tensor] = buffer; return buffer; } From 9ddfb283d169a929e7081b09d084e9c00a88f7cb Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 1 Jul 2021 06:41:12 +0900 Subject: [PATCH 48/90] restore more realize_scope This reverts commit b66c3baa54feeb8e34016713a1be21802b3296bf. 
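With PATCH 46/90 and 47/90 applied, the scope set on a stage flows through BuildRealize into ProducerRealize, and schedule_postproc_to_primfunc turns it into a Buffer whose data Var carries that scope; the commit below only restores the realize_scope attribute for passes that still consume it. A minimal sketch of the user-facing path, assuming stock te Python APIs (illustrative, not code from this series):

    import tvm
    from tvm import te

    A = te.placeholder((16,), name="A")
    B = te.compute((16,), lambda i: A[i] + 1.0, name="B")
    C = te.compute((16,), lambda i: B[i] * 2.0, name="C")
    s = te.create_schedule(C.op)
    s[B].set_scope("local")  # recorded on the stage, forwarded by BuildRealize
    print(tvm.lower(s, [A, C]))  # B's allocation should now carry scope "local"
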
--- src/te/schedule/schedule_ops.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index f130e1fb93e4..21edd2f94b20 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -175,7 +175,8 @@ class SchedulePostProc : public StmtExprMutator { thread_extent_scope_.erase(op->node.get()); return ret; } - } else if (op->attr_key == tir::attr::double_buffer_scope) { + } else if (op->attr_key == tir::attr::realize_scope || + op->attr_key == tir::attr::double_buffer_scope) { auto it = replace_op_.find(op->node.get()); if (it != replace_op_.end()) { if (it->second.defined()) { From 416169c91ae12b2e80558749435f7507e2f22a26 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 1 Jul 2021 07:21:31 +0900 Subject: [PATCH 49/90] make the default scope be "" instead of None in ir builder --- python/tvm/tir/ir_builder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 5aae068f4d58..484d00f9611a 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -394,7 +394,7 @@ def let(self, var_name, value): self.emit(lambda x: _stmt.LetStmt(var, value, x)) return var - def allocate(self, dtype, shape, name="buf", scope=None): + def allocate(self, dtype, shape, name="buf", scope=""): """Create a allocate statement. Parameters From da3053e62353d6a88c9fb10f3b863a44e8beace3 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 1 Jul 2021 07:29:31 +0900 Subject: [PATCH 50/90] restore realize_scope visit in storage_flatten.cc --- src/tir/transforms/storage_flatten.cc | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 43fc1f1ec53f..fab007c5e4d3 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -79,7 +79,6 @@ class StorageFlattener : public StmtExprMutator { Stmt VisitStmt_(const AttrStmtNode* op) final { if (op->attr_key == attr::realize_scope) { - storage_scope_[op->node.get()] = op->value.as()->value; return this->VisitStmt(op->body); } else if (op->attr_key == attr::double_buffer_scope && op->node->IsInstance()) { @@ -156,10 +155,8 @@ class StorageFlattener : public StmtExprMutator { shape.push_back(r->extent); } // deduce current storage scope. - auto it = storage_scope_.find(op->buffer.get()); - ICHECK(it != storage_scope_.end()) << "Cannot find storage scope of " << op->buffer; StorageScope skey; - const std::string& strkey = it->second; + std::string strkey = GetStorageScope(op->buffer->data); if (strkey.length() == 0) { if (curr_thread_scope_.size() != 0) { skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank); @@ -491,8 +488,6 @@ class StorageFlattener : public StmtExprMutator { std::unordered_map buf_map_; // Dimension alignment std::unordered_map, ObjectPtrHash, ObjectPtrEqual> dim_align_; - // Storage scope - std::unordered_map storage_scope_; // The current thread scope. std::vector curr_thread_scope_; // Collects shapes. 
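The storage_flatten.cc change above replaces the attribute-driven scope map with GetStorageScope, which reads the scope directly off the buffer Var's pointer type. A rough Python illustration of that invariant (a sketch assuming the series is applied, not code from it):

    import tvm
    from tvm import tir

    ptr_ty = tvm.ir.PointerType(tvm.ir.PrimType("float32"), "shared")
    buf_var = tir.Var("buf", ptr_ty)
    # The scope is recoverable from the type annotation alone, which is
    # all the C++ GetStorageScope helper relies on.
    assert buf_var.type_annotation.storage_scope == "shared"
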
From 76ed1ab000d2dea43a788761b117a1812ef363d0 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 1 Jul 2021 07:30:11 +0900 Subject: [PATCH 51/90] update storage_access.cc --- src/tir/transforms/storage_access.cc | 20 ++++++-------------- src/tir/transforms/storage_access.h | 4 +--- 2 files changed, 7 insertions(+), 17 deletions(-) diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc index 00002d3587db..8f5b8d75c1d4 100644 --- a/src/tir/transforms/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -35,7 +35,7 @@ namespace tir { void StorageAccessVisitor::VisitExpr_(const LoadNode* op) { const VarNode* buf = op->buffer_var.as(); - StorageScope scope = GetScope(buf); + StorageScope scope = GetScope(op->buffer_var); if (Enabled(buf, scope)) { ICHECK(allow_append_) << op << " " << scope.to_string(); AccessEntry e; @@ -56,7 +56,7 @@ void StorageAccessVisitor::VisitStmt_(const StoreNode* op) { ICHECK_EQ(curr_stmt_.access.size(), 0U); curr_stmt_.stmt = op; const VarNode* buf = op->buffer_var.as(); - StorageScope scope = GetScope(buf); + StorageScope scope = GetScope(op->buffer_var); if (Enabled(buf, scope)) { AccessEntry e; e.threads = env_threads(); @@ -90,11 +90,7 @@ void StorageAccessVisitor::VisitStmt_(const EvaluateNode* op) { } void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - storage_scope_[buf] = StorageScope::Create(op->value.as()->value); - StmtExprVisitor::VisitStmt_(op); - } else if (op->attr_key == attr::double_buffer_write) { + if (op->attr_key == attr::double_buffer_write) { ICHECK(double_buffer_write_ == nullptr); double_buffer_write_ = op->node.as(); scope_.push_back(std::vector()); @@ -208,7 +204,7 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { PrimExpr offset = op->args[2]; PrimExpr extent = op->args[3]; const IntImmNode* flag = op->args[4].as(); - StorageScope scope = GetScope(buffer); + StorageScope scope = GetScope(GetRef(buffer)); // The buffer scope. if (Enabled(buffer, scope)) { ICHECK(allow_append_); @@ -244,12 +240,8 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { } } -StorageScope StorageAccessVisitor::GetScope(const VarNode* buf) const { - auto it = storage_scope_.find(buf); - StorageScope s; - s.rank = StorageRank::kGlobal; - if (it == storage_scope_.end()) return s; - return it->second; +StorageScope StorageAccessVisitor::GetScope(Var buffer_var) const { + return StorageScope::Create(GetStorageScope(buffer_var)); } } // namespace tir diff --git a/src/tir/transforms/storage_access.h b/src/tir/transforms/storage_access.h index 663c570fd15c..9dc4c923b054 100644 --- a/src/tir/transforms/storage_access.h +++ b/src/tir/transforms/storage_access.h @@ -118,7 +118,7 @@ class StorageAccessVisitor : public StmtExprVisitor { * \brief Get the scope of the buffer array. * \return The scope of the final buffer array. 
*/ - StorageScope GetScope(const VarNode* buf) const; + StorageScope GetScope(Var buffer_var) const; // access scope std::vector > scope_; @@ -135,8 +135,6 @@ class StorageAccessVisitor : public StmtExprVisitor { StmtEntry curr_stmt_; // The involving threads Array env_threads_; - // The storage scope of each buffer - std::unordered_map storage_scope_; }; } // namespace tir From e95146e062b66baaf51908c6894a49fa5cfe5df9 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 1 Jul 2021 07:58:19 +0900 Subject: [PATCH 52/90] make sure buffer var is of PointerType in ir builder This reverts commit e650b6c24cabd52a073064e51c2e4fee816e88fd. --- python/tvm/tir/ir_builder.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 484d00f9611a..1573d96e7d0d 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -424,7 +424,7 @@ def allocate(self, dtype, shape, name="buf", scope=""): self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x)) return BufferVar(self, buffer_var, shape, dtype) - def pointer(self, content_type, name="ptr"): + def pointer(self, content_type, name="ptr", scope=""): """Create pointer variable with content type. Parameters @@ -435,12 +435,15 @@ def pointer(self, content_type, name="ptr"): name : str, optional The name of the pointer. + scope : str, optional + The scope of the buffer. + Returns ------- ptr : BufferVar The buffer var representing the buffer. """ - buffer_var = _expr.Var(name, dtype="handle") + buffer_var = _expr.Var(name, PointerType(PrimType(content_type), scope)) return BufferVar(self, buffer_var, None, content_type) def buffer_ptr(self, buf, shape=None): From 13b2efc7389b8b6aca7b5212fed4ac30a14b83ac Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 1 Jul 2021 14:14:38 +0900 Subject: [PATCH 53/90] enforce default storage scope of global --- include/tvm/ir/type.h | 2 +- include/tvm/tir/buffer.h | 3 ++- python/tvm/tir/buffer.py | 4 ++-- src/ir/type.cc | 3 ++- src/te/operation/cross_thread_reduction.cc | 4 ++-- src/te/schedule/schedule_postproc_to_primfunc.cc | 7 ++++--- src/tir/ir/buffer.cc | 8 ++++++++ src/tir/ir/stmt.cc | 6 ++++++ src/tir/transforms/lower_thread_allreduce.cc | 12 ++++++------ src/tir/transforms/storage_flatten.cc | 4 ++++ 10 files changed, 37 insertions(+), 16 deletions(-) diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index c772650809fa..2c6e0c35a280 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -184,7 +184,7 @@ class PointerType : public Type { * \param element_type The type of the element which the pointer points to. * \param storage_scope The storage scope into which the pointer addresses */ - TVM_DLL explicit PointerType(Type element_type, String storage_scope = ""); + TVM_DLL explicit PointerType(Type element_type, String storage_scope = "global"); TVM_DEFINE_OBJECT_REF_METHODS(PointerType, Type, PointerTypeNode); }; diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index ed5718d8f358..c66fa73d8096 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -197,7 +197,7 @@ class Buffer : public ObjectRef { * \sa Buffer for complete constructor. */ TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Float(32), - String name = "buffer", String storage_scope = "", Span span = Span()); + String name = "buffer", String storage_scope = "global", Span span = Span()); /*! 
* \brief Return the storage scope associated with a buffer variable. @@ -205,6 +205,7 @@ TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Flo * \return A string representing the storage scope of this buffer variable. */ TVM_DLL String GetStorageScope(Var buffer_var); +TVM_DLL Var UpdateStorageScope(Var buffer_var, String storage_scope); /*! * \brief Base node for data producers. diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index d905a53b3303..9c78f8511903 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -140,7 +140,7 @@ def decl_buffer( data=None, strides=None, elem_offset=None, - scope="", + scope="global", data_alignment=-1, offset_factor=0, buffer_type="", @@ -250,7 +250,7 @@ def decl_buffer( # Bool is represented as uint1 in the IR, but stored as int8 storage_type = PrimType(dtype) storage_type = PrimType("int8") if storage_type.dtype == "bool" else storage_type - data = Var(name, PointerType(storage_type), span) + data = Var(name, PointerType(storage_type, scope), span) return _ffi_api.Buffer( data, dtype, diff --git a/src/ir/type.cc b/src/ir/type.cc index fe8e00329bbc..3f450cdf0392 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -44,6 +44,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); PointerType::PointerType(Type element_type, String storage_scope) { + ICHECK(storage_scope != ""); ObjectPtr n = make_object(); n->element_type = std::move(element_type); n->storage_scope = std::move(storage_scope); @@ -53,7 +54,7 @@ PointerType::PointerType(Type element_type, String storage_scope) { TVM_REGISTER_NODE_TYPE(PointerTypeNode); TVM_REGISTER_GLOBAL("ir.PointerType") - .set_body_typed([](Type element_type, String storage_scope = "") { + .set_body_typed([](Type element_type, String storage_scope = "global") { return PointerType(element_type, storage_scope); }); diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc index da20dd875ba5..a6ee10edd5a3 100644 --- a/src/te/operation/cross_thread_reduction.cc +++ b/src/te/operation/cross_thread_reduction.cc @@ -146,7 +146,7 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, for (size_t i = 0; i < size; ++i) { DataType t = reduces[i]->dtype; normal_res_handles.emplace_back("normal_reduce_temp" + std::to_string(i), - PointerType(PrimType(t))); + PointerType(PrimType(t), "local")); lhs.push_back(Load(t, normal_res_handles[i], 0, const_true(t.lanes()))); } Array init_value = combiner->identity_element; @@ -177,7 +177,7 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, std::vector res_handles(size); for (size_t idx = 0; idx < size; ++idx) { DataType dtype = reduces[idx]->dtype; - res_handles[idx] = Var("reduce_temp" + std::to_string(idx), PointerType(PrimType(dtype))); + res_handles[idx] = Var("reduce_temp" + std::to_string(idx), PointerType(PrimType(dtype), "local")); freduce_args.push_back(res_handles[idx]); } diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc index 8e6cc131b76e..e9caeabcabd0 100644 --- a/src/te/schedule/schedule_postproc_to_primfunc.cc +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -49,7 +49,7 @@ namespace tvm { namespace te { // create a buffer for tensor. 
-Buffer CreateBufferFor(const Tensor& tensor, String storage_scope = "") { +Buffer CreateBufferFor(const Tensor& tensor, String storage_scope = "global") { std::string name = tensor->op->name; if (tensor->op->num_outputs() != 1) { name += ".v" + std::to_string(tensor->value_index); @@ -122,11 +122,12 @@ class TensorToBufferMapper : public StmtExprMutator { } private: - Buffer GetOrAllocBuffer(const Tensor& tensor, String storage_scope = "") { + Buffer GetOrAllocBuffer(const Tensor& tensor, String storage_scope = "global") { return GetBuffer(tensor, storage_scope, true); } - Buffer GetBuffer(const Tensor& tensor, String storage_scope = "", bool allow_alloc = false) { + Buffer GetBuffer(const Tensor& tensor, String storage_scope = "global", + bool allow_alloc = false) { auto it = buffer_map_.find(tensor); if (it != buffer_map_.end()) return it->second; ICHECK(allow_alloc) << "Cannot find the Realization point of tensor " << tensor; diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 851d440a6378..704afed689cd 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -48,6 +48,7 @@ Array SimplifyArray(arith::Analyzer* ana, Array array) { Buffer decl_buffer(Array shape, DataType dtype, String name, String storage_scope, Span span) { DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype); + if (storage_scope == "") storage_scope = "global"; return Buffer(Var(name, PointerType(PrimType(storage_dtype), storage_scope), span), dtype, shape, Array(), PrimExpr(), name, "", 0, 0, kDefault, span); } @@ -59,6 +60,13 @@ String GetStorageScope(Var buffer_var) { return ptr_type->storage_scope; } +Var UpdateStorageScope(Var buffer_var, String storage_scope) { + auto* ptr_type = buffer_var->type_annotation.as(); + ICHECK(ptr_type) << "The provided variable is not of pointer type"; + return Var(buffer_var->name_hint, PointerType(ptr_type->element_type, storage_scope), + buffer_var->span); +} + // Split the given expression w.r.t the add operator inline std::vector ExprSplitAddition(const PrimExpr& expr) { using namespace tir; diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 6fdeb30ec100..08d8e15dd2b7 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -61,6 +61,12 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // AttrStmt AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span) { + if (attr_key == attr::storage_scope) { + const VarNode* buf = node.as(); + CHECK(buf); + CHECK(value.as()->value == GetStorageScope(GetRef(buf))) + << value.as()->value << ", " << GetStorageScope(GetRef(buf)); + } auto n = make_object(); n->node = node; n->attr_key = std::move(attr_key); diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 9e536814fa12..4a1b31fb8dd4 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -85,13 +85,13 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { if (it != alloc_remap_.end()) { const AllocateNode* repl = it->second.as(); if (warp_allocs_.count(repl)) { - stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body); - stmt = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("local"), stmt); + stmt = Allocate(UpdateStorageScope(repl->buffer_var, "local"), repl->dtype, repl->extents, + repl->condition, op->body); } else { // use volatile access to shared buffer. 
stmt = AttrStmt(repl->buffer_var, attr::volatile_scope, 1, op->body); - stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, stmt); - stmt = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("shared"), stmt); + stmt = Allocate(UpdateStorageScope(repl->buffer_var, "shared"), repl->dtype, repl->extents, + repl->condition, stmt); } return stmt; } else { @@ -365,8 +365,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { for (auto var : local_vars) { const AllocateNode* repl = var.as(); if (repl) { - body = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body); - body = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("local"), body); + body = Allocate(UpdateStorageScope(repl->buffer_var, "local"), repl->dtype, repl->extents, + repl->condition, body); } } diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index fab007c5e4d3..57fae5069e3a 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -197,6 +197,7 @@ class StorageFlattener : public StmtExprMutator { strides = Array(rstrides.rbegin(), rstrides.rend()); } + LOG(INFO) << "skey: " << skey.to_string(); e.buffer = Buffer(Var(op->buffer->data->name_hint, op->buffer->data->type_annotation), op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name, skey.to_string(), align, 0, kDefault); @@ -225,6 +226,9 @@ class StorageFlattener : public StmtExprMutator { ret = Allocate(e.buffer->data, storage_type, shape, make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); } + CHECK(e.buffer->scope == GetStorageScope(e.buffer->data)) + << e.buffer->scope << ", " << GetStorageScope(e.buffer->data) << ", " + << GetStorageScope(op->buffer->data); ret = AttrStmt(e.buffer->data, attr::storage_scope, StringImm(e.buffer->scope), ret); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { From ae8858a378c9574d0e7f401a053a07a210c3f717 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 2 Jul 2021 07:03:53 +0900 Subject: [PATCH 54/90] added remap pass but does not work yet --- src/tir/transforms/lower_thread_allreduce.cc | 58 +++++++++++++++++--- 1 file changed, 51 insertions(+), 7 deletions(-) diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 4a1b31fb8dd4..a9f7671b519e 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -37,6 +37,44 @@ namespace tvm { namespace tir { +class RemapStorageScope final : public StmtExprMutator { + public: + explicit RemapStorageScope(const std::unordered_map& new_var_remap) + : new_var_remap_(new_var_remap) {} + + PrimExpr VisitExpr_(const VarNode* op) final { + auto it = new_var_remap_.find(op); + LOG(INFO) << "Visit " << op->name_hint; + if (it == new_var_remap_.end()) { + return GetRef(op); + } + LOG(INFO) << "Remapped " << op->name_hint; + return it->second; + } + + Stmt VisitStmt_(const AllocateNode* op) final { + LOG(INFO) << "Visit alloc node with buffer " << op->buffer_var->name_hint; + auto remapped = StmtExprMutator::VisitExpr(op->buffer_var); + auto body = StmtExprMutator::VisitStmt(op->body); + return Allocate(Downcast(remapped), op->dtype, op->extents, op->condition, body); + } + + Stmt VisitStmt_(const StoreNode* op) final { + auto remapped = StmtExprMutator::VisitExpr(op->buffer_var); + return Store(Downcast(remapped), StmtExprMutator::VisitExpr(op->value), + StmtExprMutator::VisitExpr(op->index), 
StmtExprMutator::VisitExpr(op->predicate)); + } + + PrimExpr VisitExpr_(const LoadNode* op) final { + auto remapped = StmtExprMutator::VisitExpr(op->buffer_var); + return Load(op->dtype, Downcast(remapped), StmtExprMutator::VisitExpr(op->index), + StmtExprMutator::VisitExpr(op->predicate)); + } + + private: + std::unordered_map new_var_remap_; +}; + class ThreadAllreduceBuilder final : public StmtExprMutator { public: explicit ThreadAllreduceBuilder(const TargetNode* target) @@ -85,13 +123,14 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { if (it != alloc_remap_.end()) { const AllocateNode* repl = it->second.as(); if (warp_allocs_.count(repl)) { - stmt = Allocate(UpdateStorageScope(repl->buffer_var, "local"), repl->dtype, repl->extents, - repl->condition, op->body); + stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body); + new_var_remap_[repl->buffer_var.get()] = UpdateStorageScope(repl->buffer_var, "local"); } else { // use volatile access to shared buffer. stmt = AttrStmt(repl->buffer_var, attr::volatile_scope, 1, op->body); - stmt = Allocate(UpdateStorageScope(repl->buffer_var, "shared"), repl->dtype, repl->extents, - repl->condition, stmt); + stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, stmt); + LOG(INFO) << "make remap for " << repl->buffer_var->name_hint; + new_var_remap_[repl->buffer_var.get()] = UpdateStorageScope(repl->buffer_var, "shared"); } return stmt; } else { @@ -108,6 +147,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } } + std::unordered_map new_var_remap_; + private: // Thread entry struct ThreadEntry { @@ -365,8 +406,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { for (auto var : local_vars) { const AllocateNode* repl = var.as(); if (repl) { - body = Allocate(UpdateStorageScope(repl->buffer_var, "local"), repl->dtype, repl->extents, - repl->condition, body); + body = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body); + LOG(INFO) << "make remap forr " << repl->buffer_var->name_hint; + new_var_remap_[repl->buffer_var.get()] = UpdateStorageScope(repl->buffer_var, "local"); } } @@ -590,7 +632,9 @@ Pass LowerThreadAllreduce() { auto target = f->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "LowerThreadAllreduce: Require the target attribute"; const TargetNode* target_node = target.as(); - n->body = ThreadAllreduceBuilder(target_node)(n->body); + ThreadAllreduceBuilder thread_all_reduce(target_node); + auto reduce_body = thread_all_reduce(n->body); + n->body = RemapStorageScope(thread_all_reduce.new_var_remap_)(reduce_body); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {}); From cec5a0a9245fc6f1dbb04ccc1abe8db632c30726 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 2 Jul 2021 13:40:01 +0900 Subject: [PATCH 55/90] fixed all reduce issue This reverts commit 8e20003c5325085ed22ee57180aca18644b3b5ab. 
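The fix keeps the two-pass structure of the previous commit: ThreadAllreduceBuilder records which buffer vars need a new scope, and RemapStorageScope rebuilds each var with a scoped PointerType, rewrites it into Allocate, Store, and Load by hand (the stock mutators do not touch those buffer_var fields, which is why all three visitors are overridden), and re-tags the storage_scope attribute so it stays consistent with the new pointer type. A hedged Python analogue of the var rebuild, with remap_scope an illustrative name rather than anything from the patch:

    import tvm
    from tvm import tir

    def remap_scope(old_var, scope):
        # Rebuild the buffer var so its PointerType carries the new scope;
        # the C++ pass then substitutes it into Allocate/Store/Load itself.
        elem_ty = old_var.type_annotation.element_type
        return tir.Var(old_var.name, tvm.ir.PointerType(elem_ty, scope))
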
--- src/tir/transforms/lower_thread_allreduce.cc | 12 +++++------- src/tir/transforms/storage_flatten.cc | 16 ++++++---------- 2 files changed, 11 insertions(+), 17 deletions(-) diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index a9f7671b519e..5a2bd8ac94c1 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -44,19 +44,17 @@ class RemapStorageScope final : public StmtExprMutator { PrimExpr VisitExpr_(const VarNode* op) final { auto it = new_var_remap_.find(op); - LOG(INFO) << "Visit " << op->name_hint; if (it == new_var_remap_.end()) { return GetRef(op); } - LOG(INFO) << "Remapped " << op->name_hint; return it->second; } Stmt VisitStmt_(const AllocateNode* op) final { - LOG(INFO) << "Visit alloc node with buffer " << op->buffer_var->name_hint; - auto remapped = StmtExprMutator::VisitExpr(op->buffer_var); + auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); auto body = StmtExprMutator::VisitStmt(op->body); - return Allocate(Downcast(remapped), op->dtype, op->extents, op->condition, body); + auto stmt = Allocate(remapped, op->dtype, op->extents, op->condition, body); + return AttrStmt(remapped, attr::storage_scope, StringImm(GetStorageScope(remapped)), stmt); } Stmt VisitStmt_(const StoreNode* op) final { @@ -129,7 +127,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // use volatile access to shared buffer. stmt = AttrStmt(repl->buffer_var, attr::volatile_scope, 1, op->body); stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, stmt); - LOG(INFO) << "make remap for " << repl->buffer_var->name_hint; + LOG(INFO) << "make remap for " << repl->buffer_var->name_hint; new_var_remap_[repl->buffer_var.get()] = UpdateStorageScope(repl->buffer_var, "shared"); } return stmt; @@ -407,7 +405,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { const AllocateNode* repl = var.as(); if (repl) { body = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body); - LOG(INFO) << "make remap forr " << repl->buffer_var->name_hint; + LOG(INFO) << "make remap forr " << repl->buffer_var->name_hint; new_var_remap_[repl->buffer_var.get()] = UpdateStorageScope(repl->buffer_var, "local"); } } diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 57fae5069e3a..91d5792df97a 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -157,10 +157,8 @@ class StorageFlattener : public StmtExprMutator { // deduce current storage scope. 
StorageScope skey; std::string strkey = GetStorageScope(op->buffer->data); - if (strkey.length() == 0) { - if (curr_thread_scope_.size() != 0) { - skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank); - } + if(curr_thread_scope_.size() != 0 && (strkey == "" || strkey == "global")) { + skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank); } else { skey = StorageScope::Create(strkey); } @@ -197,9 +195,10 @@ class StorageFlattener : public StmtExprMutator { strides = Array(rstrides.rbegin(), rstrides.rend()); } - LOG(INFO) << "skey: " << skey.to_string(); - e.buffer = Buffer(Var(op->buffer->data->name_hint, op->buffer->data->type_annotation), - op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name, + auto* ptr_type = op->buffer->data->type_annotation.as(); + ICHECK(ptr_type); + auto new_var = Var(op->buffer->data->name_hint, PointerType(ptr_type->element_type, skey.to_string())); + e.buffer = Buffer(new_var, op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name, skey.to_string(), align, 0, kDefault); buf_map_[key] = e; @@ -226,9 +225,6 @@ class StorageFlattener : public StmtExprMutator { ret = Allocate(e.buffer->data, storage_type, shape, make_const(DataType::Bool(e.buffer->dtype.lanes()), true), body); } - CHECK(e.buffer->scope == GetStorageScope(e.buffer->data)) - << e.buffer->scope << ", " << GetStorageScope(e.buffer->data) << ", " - << GetStorageScope(op->buffer->data); ret = AttrStmt(e.buffer->data, attr::storage_scope, StringImm(e.buffer->scope), ret); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { From e4f09653db314b81e8f02b70b4aabf500e744db6 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 2 Jul 2021 13:48:11 +0900 Subject: [PATCH 56/90] simplify --- src/tir/transforms/lower_thread_allreduce.cc | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 5a2bd8ac94c1..8cb7db902db6 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -52,8 +52,8 @@ class RemapStorageScope final : public StmtExprMutator { Stmt VisitStmt_(const AllocateNode* op) final { auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); - auto body = StmtExprMutator::VisitStmt(op->body); - auto stmt = Allocate(remapped, op->dtype, op->extents, op->condition, body); + auto stmt = Allocate(remapped, op->dtype, op->extents, op->condition, + StmtExprMutator::VisitStmt(op->body)); return AttrStmt(remapped, attr::storage_scope, StringImm(GetStorageScope(remapped)), stmt); } @@ -127,7 +127,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { // use volatile access to shared buffer. 
stmt = AttrStmt(repl->buffer_var, attr::volatile_scope, 1, op->body); stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, stmt); - LOG(INFO) << "make remap for " << repl->buffer_var->name_hint; new_var_remap_[repl->buffer_var.get()] = UpdateStorageScope(repl->buffer_var, "shared"); } return stmt; @@ -405,7 +404,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { const AllocateNode* repl = var.as(); if (repl) { body = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body); - LOG(INFO) << "make remap forr " << repl->buffer_var->name_hint; new_var_remap_[repl->buffer_var.get()] = UpdateStorageScope(repl->buffer_var, "local"); } } From fca791c96eb3fea123d21a37feebfb052494dfc4 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 2 Jul 2021 13:50:13 +0900 Subject: [PATCH 57/90] trying mitigation for aot test --- src/tir/transforms/storage_access.cc | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc index 8f5b8d75c1d4..952758c4e5a7 100644 --- a/src/tir/transforms/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -241,7 +241,10 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { } StorageScope StorageAccessVisitor::GetScope(Var buffer_var) const { - return StorageScope::Create(GetStorageScope(buffer_var)); + if (buffer_var->type_annotation.as()) { + return StorageScope::Create(GetStorageScope(buffer_var)); + } + return StorageScope(); // global by default } } // namespace tir From a80dfad28b73a4d51ad58e8c3316e640c8cc4f95 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 2 Jul 2021 14:00:52 +0900 Subject: [PATCH 58/90] merge remaining changes from initial branch --- src/target/spirv/codegen_spirv.cc | 14 +++++----- src/target/spirv/codegen_spirv.h | 2 -- src/tir/transforms/storage_rewrite.cc | 37 +++++++++------------------ 3 files changed, 19 insertions(+), 34 deletions(-) diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 5d52bee44e98..7c9dfcaf95e0 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -23,6 +23,7 @@ */ #include "codegen_spirv.h" +#include #include #include #include @@ -644,13 +645,14 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) { ICHECK(!op->dtype.is_handle()); int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; + spirv::Value buf; - StorageInfo& info = storage_info_[op->buffer_var.get()]; + auto storage_scope = runtime::StorageScope::Create(GetStorageScope(op->buffer_var)); spirv::SType etype = builder_->GetSType(op->dtype); - if (info.scope.rank == runtime::StorageRank::kLocal) { + if (storage_scope.rank == runtime::StorageRank::kLocal) { buf = builder_->Allocate(etype, static_cast(constant_size), spv::StorageClassFunction); - } else if (info.scope.rank == runtime::StorageRank::kShared) { + } else if (storage_scope.rank == runtime::StorageRank::kShared) { // Shared memory buf = builder_->Allocate(etype, static_cast(constant_size), spv::StorageClassWorkgroup); @@ -660,8 +662,10 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) { builder_->SetName(buf, op->buffer_var->name_hint); + StorageInfo& info = storage_info_[op->buffer_var.get()]; ICHECK(!info.content_fixed); info.UpdateContentType(op->dtype); + ICHECK(!var_map_.count(op->buffer_var.get())); var_map_[op->buffer_var.get()] = buf; 
this->VisitStmt(op->body); @@ -677,10 +681,6 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) { var_map_[iv->var.get()] = GetThreadIndex(iv, op->value); } } - } else if (op->attr_key == tir::attr::storage_scope) { - const VarNode* v = op->node.as(); - ICHECK(v); - storage_info_[v].scope = runtime::StorageScope::Create(op->value.as()->value); } else if (op->attr_key == tir::attr::volatile_scope) { const VarNode* v = op->node.as(); ICHECK(v); diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index 3868322a74e0..a44dc5fd3d34 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -116,8 +116,6 @@ class CodeGenSPIRV : public ExprFunctor, protected: /*! \brief The storage information */ struct StorageInfo { - /*! \brief The storage scope */ - runtime::StorageScope scope; /*! \brief Whether it is volatile */ bool is_volatile{false}; /*! \brief Whether it is volatile */ diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index c755576e2b88..d39d1f2ed901 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -75,8 +75,6 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { }; // The scope of each allocation struct AllocEntry { - // Scope used for allocation. - StorageScope storage_scope; // scope level size_t level{0}; // allocation stmt @@ -86,13 +84,8 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { void VisitStmt_(const AllocateNode* op) final { size_t level = scope_.size(); const VarNode* buf = op->buffer_var.get(); - auto it = alloc_info_.find(buf); - ICHECK(it != alloc_info_.end()) << "Could not find buffer `" << buf->name_hint - << "` in the list of allocated buffers. Perhaps you are " - "missing a storage_scope attr for this buffer."; - ICHECK(it->second.alloc == nullptr); - it->second.alloc = op; - it->second.level = level; + alloc_info_[buf].alloc = op; + alloc_info_[buf].level = level; StmtExprVisitor::VisitStmt_(op); } void VisitStmt_(const StoreNode* op) final { @@ -180,10 +173,6 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { VisitNewScope(op); } else if (op->attr_key == attr::virtual_thread) { VisitNewScope(op); - } else if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - alloc_info_[buf].storage_scope = StorageScope::Create(op->value.as()->value); - StmtExprVisitor::VisitStmt_(op); } else { StmtExprVisitor::VisitStmt_(op); } @@ -409,10 +398,8 @@ class StoragePlanRewriter : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::storage_scope) { - return this->VisitStmt(op->body); - } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread || - attr::IsPragmaKey(op->attr_key)) { + if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread || + attr::IsPragmaKey(op->attr_key)) { // remake all the allocation at the attach scope. 
if (attach_map_.count(op)) { auto& svec = attach_map_[op]; @@ -716,7 +703,8 @@ class StoragePlanRewriter : public StmtExprMutator { for (const VarNode* var : it->second.gen) { ICHECK(alloc_info.count(var)); - const AllocEntry& ae = alloc_info.at(var); + const AllocateNode* alloc = alloc_info.at(var).alloc; + auto storage_scope = StorageScope::Create(GetStorageScope(GetRef(var))); StorageEntry* dst_entry = nullptr; // inplace detection if (detect_inplace) { @@ -726,13 +714,12 @@ class StoragePlanRewriter : public StmtExprMutator { if (!inplace_flag.count(src) && alloc_map_.count(src)) { InplaceOpVerifier visitor; StorageEntry* src_entry = alloc_map_.at(src); - if (src_entry->scope == ae.storage_scope && + if (src_entry->scope == storage_scope && src_entry->attach_scope_ == thread_scope_ && - src_entry->elem_type == ae.alloc->dtype.element_of() && + src_entry->elem_type == alloc->dtype.element_of() && visitor.Check(s.stmt, var, src)) { - uint64_t const_nbits = - static_cast(ae.alloc->constant_allocation_size()) * - ae.alloc->dtype.bits() * ae.alloc->dtype.lanes(); + uint64_t const_nbits = static_cast(alloc->constant_allocation_size()) * + alloc->dtype.bits() * alloc->dtype.lanes(); if (src_entry->const_nbits == const_nbits && !inplace_found) { // successfully inplace dst_entry = src_entry; @@ -744,9 +731,9 @@ class StoragePlanRewriter : public StmtExprMutator { } } if (dst_entry == nullptr) { - dst_entry = FindAlloc(ae.alloc, thread_scope_, ae.storage_scope); + dst_entry = FindAlloc(alloc, thread_scope_, storage_scope); } - dst_entry->allocs.emplace_back(ae.alloc); + dst_entry->allocs.emplace_back(alloc); alloc_map_[var] = dst_entry; } } From 6fd0fb3d3906002480be8bdb3bfa3563d76609d4 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 2 Jul 2021 14:11:50 +0900 Subject: [PATCH 59/90] remove use of attr::storage_scope from codegen --- src/target/llvm/codegen_amdgpu.cc | 27 ++++++++-------- src/target/llvm/codegen_cpu.cc | 6 ++-- src/target/llvm/codegen_llvm.cc | 53 ++++++++++++++----------------- src/target/llvm/codegen_llvm.h | 11 ++----- src/target/llvm/codegen_nvptx.cc | 26 +++++++-------- src/target/source/codegen_cuda.cc | 8 ++--- 6 files changed, 58 insertions(+), 73 deletions(-) diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 78f8a50e4e1b..01d2b2f7ad4d 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -76,30 +76,31 @@ class CodeGenAMDGPU : public CodeGenLLVM { int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; - StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; - if (constant_size % 4 == 0 && info.alignment == 0) { - info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); + int& alignment = alloc_storage_alignment_[op->buffer_var.get()]; + if (constant_size % 4 == 0 && alignment == 0) { + alignment = GetTempAllocaAlignment(op->dtype, constant_size); } // maximum necessary alignment in the AMD devices - if (info.alignment > 16) { - info.alignment = 16; + if (alignment > 16) { + alignment = 16; } - if (info.scope.rank == runtime::StorageRank::kLocal) { + auto storage_scope = runtime::StorageScope::Create(GetStorageScope(op->buffer_var)); + if (storage_scope.rank == runtime::StorageRank::kLocal) { // const int local_address_space = 5; // TODO(tqchen): for higher version of LLVM, local address space can be set. 
llvm::AllocaInst* alloca = WithFunctionEntry([&]() { return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); }); - if (alloca->getAlignment() < static_cast(info.alignment)) { + if (alloca->getAlignment() < static_cast(alignment)) { #if TVM_LLVM_VERSION >= 100 - alloca->setAlignment(llvm::Align(info.alignment)); + alloca->setAlignment(llvm::Align(alignment)); #else - alloca->setAlignment(info.alignment); + alloca->setAlignment(alignment); #endif } buf = alloca; } else { - ICHECK(info.scope.rank == runtime::StorageRank::kShared) + ICHECK(storage_scope.rank == runtime::StorageRank::kShared) << "Can only allocate shared or local memory inside kernel"; // Shared memory: address space == 3 const unsigned shared_address_space = 3; @@ -108,11 +109,11 @@ class CodeGenAMDGPU : public CodeGenLLVM { llvm::GlobalVariable* global = new llvm::GlobalVariable( *module_, type, false, llvm::GlobalValue::PrivateLinkage, nullptr, ".shared", nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space); - if (global->getAlignment() < static_cast(info.alignment)) { + if (global->getAlignment() < static_cast(alignment)) { #if TVM_LLVM_VERSION >= 100 - global->setAlignment(llvm::Align(info.alignment)); + global->setAlignment(llvm::Align(alignment)); #else - global->setAlignment(info.alignment); + global->setAlignment(alignment); #endif } buf = global; diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index ab96d6e69d14..b9761355b208 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -463,9 +463,9 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { } // Add alignment attribute if needed. #if TVM_LLVM_VERSION >= 50 - auto f = alloc_storage_info_.find(var.get()); - if (f != alloc_storage_info_.end()) { - unsigned align = f->second.alignment; + auto f = alloc_storage_alignment_.find(var.get()); + if (f != alloc_storage_alignment_.end()) { + unsigned align = f->second; if (align > 1) { auto attr = llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, align); fcompute->addParamAttr(idx, attr); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 48ccefafe3c4..545c94dddae3 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -104,7 +104,7 @@ void CodeGenLLVM::AddFunction(const PrimFunc& f) { this->AddFunctionInternal(f, void CodeGenLLVM::InitFuncState() { var_map_.clear(); alias_var_set_.clear(); - alloc_storage_info_.clear(); + alloc_storage_alignment_.clear(); volatile_buf_.clear(); analyzer_.reset(new arith::Analyzer()); } @@ -165,9 +165,9 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { #if TVM_LLVM_VERSION >= 50 for (size_t i = 0; i < f->params.size(); ++i) { const Var& var = f->params[i]; - auto f = alloc_storage_info_.find(var.get()); - if (f != alloc_storage_info_.end()) { - unsigned align = f->second.alignment; + auto f = alloc_storage_alignment_.find(var.get()); + if (f != alloc_storage_alignment_.end()) { + unsigned align = f->second; if (align > 1) { auto attr = llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, align); function_->addParamAttr(i, attr); @@ -498,11 +498,12 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, const VarNode* buffer, P void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExpr& index, int* p_alignment, int* p_native_bits) { int max_align_bits = t.bits(); - auto it = alloc_storage_info_.find(buf_var); - if (it != alloc_storage_info_.end()) { - const 
StorageInfo& info = it->second; - *p_native_bits = NativeVectorBits(info.scope); - max_align_bits = info.alignment * 8; + auto it = alloc_storage_alignment_.find(buf_var); + if (it != alloc_storage_alignment_.end()) { + const int alignment = it->second; + *p_native_bits = + NativeVectorBits(runtime::StorageScope::Create(GetStorageScope(GetRef(buf_var)))); + max_align_bits = alignment * 8; } else { *p_native_bits = native_vector_bits_; } @@ -1353,25 +1354,25 @@ void CodeGenLLVM::VisitStmt_(const AllocateNode* op) { int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation"; - StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; - if (constant_size % 4 == 0 && info.alignment == 0) { - info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); + int& alignment = alloc_storage_alignment_[op->buffer_var.get()]; + if (constant_size % 4 == 0 && alignment == 0) { + alignment = GetTempAllocaAlignment(op->dtype, constant_size); } // maximum necessary alignment in the NV devices - if (info.alignment > 16) { - info.alignment = 16; + if (alignment > 16) { + alignment = 16; } llvm::AllocaInst* alloca = WithFunctionEntry([&]() { return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); }); - if (alloca->getAlignment() < static_cast(info.alignment)) { + if (alloca->getAlignment() < static_cast(alignment)) { #if TVM_LLVM_VERSION >= 100 - alloca->setAlignment(llvm::Align(info.alignment)); + alloca->setAlignment(llvm::Align(alignment)); #else - alloca->setAlignment(info.alignment); + alloca->setAlignment(alignment); #endif } - info.alignment = alloca->getAlignment(); + alignment = alloca->getAlignment(); buf = alloca; buf = builder_->CreatePointerCast( @@ -1390,18 +1391,13 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) { analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value)); } } - } else if (op->attr_key == tir::attr::storage_scope) { - const VarNode* v = op->node.as(); - ICHECK(v); - alloc_storage_info_[v].scope = - runtime::StorageScope::Create(op->value.as()->value); } else if (op->attr_key == tir::attr::storage_alignment) { const VarNode* v = op->node.as(); ICHECK(v); - alloc_storage_info_[v].alignment = static_cast(op->value.as()->value); - if (var_map_.count(v) && alloc_storage_info_[v].alignment > 1) { + alloc_storage_alignment_[v] = static_cast(op->value.as()->value); + if (var_map_.count(v) && alloc_storage_alignment_[v] > 1) { builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v), - alloc_storage_info_[v].alignment); + alloc_storage_alignment_[v]); } } else if (op->attr_key == tir::attr::volatile_scope) { const VarNode* v = op->node.as(); @@ -1426,9 +1422,8 @@ void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) { } var_map_[v] = MakeValue(op->value); analyzer_->Bind(op->var, op->value); - if (alloc_storage_info_.count(v) && alloc_storage_info_[v].alignment > 1) { - builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v), - alloc_storage_info_[v].alignment); + if (alloc_storage_alignment_.count(v) && alloc_storage_alignment_[v] > 1) { + builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v), alloc_storage_alignment_[v]); } this->VisitStmt(op->body); } diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index d5fcfab6d889..fb13ce42f897 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -161,13 +161,6 @@ class CodeGenLLVM : public ExprFunctor, void VisitStmt_(const 
EvaluateNode* op) override; protected: - /*! \brief The storage information */ - struct StorageInfo { - /*! \brief The storage scope */ - runtime::StorageScope scope; - /*! \brief The alignment of allocation */ - int alignment{0}; - }; /*! * \brief Execute falloca at the beginning of the * currrent function and obtain its return value. @@ -327,8 +320,8 @@ class CodeGenLLVM : public ExprFunctor, std::vector > link_modules_; /*! \brief native vector bits of current targetx*/ int native_vector_bits_{0}; - /*! \brief the storage scope of allocation */ - std::unordered_map alloc_storage_info_; + /*! \brief the alignment of allocation */ + std::unordered_map alloc_storage_alignment_; // The definition of local variable. std::unordered_map var_map_; // global strings diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index 9e56529ec9ef..e8ae088ece32 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -51,31 +51,31 @@ class CodeGenNVPTX : public CodeGenLLVM { int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; - StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; - if (constant_size % 4 == 0 && info.alignment == 0) { - info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); + int& alignment = alloc_storage_alignment_[op->buffer_var.get()]; + if (constant_size % 4 == 0 && alignment == 0) { + alignment = GetTempAllocaAlignment(op->dtype, constant_size); } // maximum necessary alignment in the NV devices - if (info.alignment > 16) { - info.alignment = 16; + if (alignment > 16) { + alignment = 16; } - - if (info.scope.rank == runtime::StorageRank::kLocal) { + auto storage_scope = runtime::StorageScope::Create(GetStorageScope(op->buffer_var)); + if (storage_scope.rank == runtime::StorageRank::kLocal) { // const int local_address_space = 5; // TODO(tqchen): for higher version of LLVM, local address space can be set. 
llvm::AllocaInst* alloca = WithFunctionEntry([&]() { return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); }); - if (alloca->getAlignment() < static_cast(info.alignment)) { + if (alloca->getAlignment() < static_cast(alignment)) { #if TVM_LLVM_VERSION >= 100 - alloca->setAlignment(llvm::Align(info.alignment)); + alloca->setAlignment(llvm::Align(alignment)); #else - alloca->setAlignment(info.alignment); + alloca->setAlignment(alignment); #endif } buf = alloca; } else { - ICHECK(info.scope.rank == runtime::StorageRank::kShared) + ICHECK(storage_scope.rank == runtime::StorageRank::kShared) << "Can only allocate shared or local memory inside kernel"; // Shared memory: address space == 3 const unsigned shared_address_space = 3; @@ -85,9 +85,9 @@ class CodeGenNVPTX : public CodeGenLLVM { *module_, type, false, llvm::GlobalValue::PrivateLinkage, nullptr, ".shared", nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space); #if TVM_LLVM_VERSION >= 100 - global->setAlignment(llvm::Align(info.alignment)); + global->setAlignment(llvm::Align(alignment)); #else - global->setAlignment(info.alignment); + global->setAlignment(alignment); #endif buf = global; } diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 6e76c3538e71..66b401c731e3 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -705,12 +705,7 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { this->PrintIndent(); int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; - const VarNode* buffer = op->buffer_var.as(); - auto it = alloc_storage_scope_.find(buffer); - ICHECK(it != alloc_storage_scope_.end()) - << "Buffer " << op->buffer_var << " is missing an AttrStmt with a \"storage_scope\" key"; - - std::string scope = it->second; + std::string scope = GetStorageScope(op->buffer_var); if (scope.find("wmma.") == 0) { if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) || @@ -724,6 +719,7 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { op->dtype == DataType::Int(32)) << "Accumulator only support half, float and int type for now"; } + const VarNode* buffer = op->buffer_var.as(); constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); PrintWmmaScope(scope, op->dtype, buffer, stream); } else { From b3fa275abe92ee459a75d3a88ee31cc54956684e Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 2 Jul 2021 14:12:19 +0900 Subject: [PATCH 60/90] restore a visit to AttrStmt with attr::storage_scope in storage_rewrite --- src/tir/transforms/storage_rewrite.cc | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index d39d1f2ed901..e99b0c0e1086 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -398,8 +398,10 @@ class StoragePlanRewriter : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread || - attr::IsPragmaKey(op->attr_key)) { + if (op->attr_key == attr::storage_scope) { + return this->VisitStmt(op->body); + } else if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread || + attr::IsPragmaKey(op->attr_key)) { // remake all the allocation at the attach scope. 
if (attach_map_.count(op)) { auto& svec = attach_map_[op]; From b62d8a143579aa72dc363721b0a6801c62ac7c9d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 2 Jul 2021 14:21:51 +0900 Subject: [PATCH 61/90] disable check --- src/ir/type.cc | 1 - src/tir/ir/stmt.cc | 13 +++++++------ 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/ir/type.cc b/src/ir/type.cc index 3f450cdf0392..567e31d9c2a6 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -44,7 +44,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) }); PointerType::PointerType(Type element_type, String storage_scope) { - ICHECK(storage_scope != ""); ObjectPtr n = make_object(); n->element_type = std::move(element_type); n->storage_scope = std::move(storage_scope); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 08d8e15dd2b7..18a5c2691005 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -61,12 +61,13 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // AttrStmt AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span) { - if (attr_key == attr::storage_scope) { - const VarNode* buf = node.as(); - CHECK(buf); - CHECK(value.as()->value == GetStorageScope(GetRef(buf))) - << value.as()->value << ", " << GetStorageScope(GetRef(buf)); - } + // TODO(masahi): Enable this invariant check + // if (attr_key == attr::storage_scope) { + // const VarNode* buf = node.as(); + // ICHECK(buf); + // ICHECK(value.as()->value == GetStorageScope(GetRef(buf))) + // << value.as()->value << ", " << GetStorageScope(GetRef(buf)); + // } auto n = make_object(); n->node = node; n->attr_key = std::move(attr_key); From 4a084d72494a0e8c5831745c8a6771750c067387 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 2 Jul 2021 14:26:23 +0900 Subject: [PATCH 62/90] lint fix --- include/tvm/tir/buffer.h | 3 ++- src/te/operation/cross_thread_reduction.cc | 4 +++- src/tir/transforms/storage_flatten.cc | 5 +++-- 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index c66fa73d8096..bf5c1ceaaf5c 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -197,7 +197,8 @@ class Buffer : public ObjectRef { * \sa Buffer for complete constructor. */ TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Float(32), - String name = "buffer", String storage_scope = "global", Span span = Span()); + String name = "buffer", String storage_scope = "global", + Span span = Span()); /*! * \brief Return the storage scope associated with a buffer variable. diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc index a6ee10edd5a3..0c20328f02b7 100644 --- a/src/te/operation/cross_thread_reduction.cc +++ b/src/te/operation/cross_thread_reduction.cc @@ -1,3 +1,4 @@ + /* * Licensed to the Apache Software Foundation (ASF) under one * or more contributor license agreements. 
See the NOTICE file @@ -177,7 +178,8 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, std::vector res_handles(size); for (size_t idx = 0; idx < size; ++idx) { DataType dtype = reduces[idx]->dtype; - res_handles[idx] = Var("reduce_temp" + std::to_string(idx), PointerType(PrimType(dtype), "local")); + res_handles[idx] = + Var("reduce_temp" + std::to_string(idx), PointerType(PrimType(dtype), "local")); freduce_args.push_back(res_handles[idx]); } diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 91d5792df97a..7af36d164a56 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -157,7 +157,7 @@ class StorageFlattener : public StmtExprMutator { // deduce current storage scope. StorageScope skey; std::string strkey = GetStorageScope(op->buffer->data); - if(curr_thread_scope_.size() != 0 && (strkey == "" || strkey == "global")) { + if (curr_thread_scope_.size() != 0 && (strkey == "" || strkey == "global")) { skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank); } else { skey = StorageScope::Create(strkey); @@ -197,7 +197,8 @@ class StorageFlattener : public StmtExprMutator { auto* ptr_type = op->buffer->data->type_annotation.as(); ICHECK(ptr_type); - auto new_var = Var(op->buffer->data->name_hint, PointerType(ptr_type->element_type, skey.to_string())); + auto new_var = + Var(op->buffer->data->name_hint, PointerType(ptr_type->element_type, skey.to_string())); e.buffer = Buffer(new_var, op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name, skey.to_string(), align, 0, kDefault); From 73911a93353d9589b7a2ade19542a20afe958341 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 2 Jul 2021 14:52:18 +0900 Subject: [PATCH 63/90] revert default scope to "" --- include/tvm/ir/type.h | 2 +- include/tvm/tir/buffer.h | 2 +- python/tvm/tir/buffer.py | 2 +- src/ir/type.cc | 2 +- src/te/schedule/schedule_postproc_to_primfunc.cc | 6 +++--- src/tir/ir/buffer.cc | 1 - src/tir/transforms/storage_flatten.cc | 7 ++++--- 7 files changed, 11 insertions(+), 11 deletions(-) diff --git a/include/tvm/ir/type.h b/include/tvm/ir/type.h index 2c6e0c35a280..c772650809fa 100644 --- a/include/tvm/ir/type.h +++ b/include/tvm/ir/type.h @@ -184,7 +184,7 @@ class PointerType : public Type { * \param element_type The type of the element which the pointer points to. * \param storage_scope The storage scope into which the pointer addresses */ - TVM_DLL explicit PointerType(Type element_type, String storage_scope = "global"); + TVM_DLL explicit PointerType(Type element_type, String storage_scope = ""); TVM_DEFINE_OBJECT_REF_METHODS(PointerType, Type, PointerTypeNode); }; diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index bf5c1ceaaf5c..fd6718a44e4b 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -197,7 +197,7 @@ class Buffer : public ObjectRef { * \sa Buffer for complete constructor. */ TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Float(32), - String name = "buffer", String storage_scope = "global", + String name = "buffer", String storage_scope = "", Span span = Span()); /*! 
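Note on the revert: with the default back to an empty string, a scope is only materialized on the pointer type when a caller passes one explicitly, and later passes (storage_flatten above) resolve the empty case. A minimal Python sketch of the intended behavior, assuming the Python decl_buffer forwards scope into the buffer var's PointerType the same way the C++ overload does:

    import tvm
    from tvm import tir

    # Explicit scope is recorded on the buffer var's pointer type.
    b_local = tir.decl_buffer((16,), "float32", name="B", scope="local")
    assert b_local.data.type_annotation.storage_scope == "local"

    # The default leaves the scope empty; storage_flatten later resolves
    # "" to a rank-dependent default (e.g. global outside a thread scope).
    b_plain = tir.decl_buffer((16,), "float32", name="C")
    assert b_plain.data.type_annotation.storage_scope == ""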
diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index 9c78f8511903..eb48e4c8068b 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -140,7 +140,7 @@ def decl_buffer( data=None, strides=None, elem_offset=None, - scope="global", + scope="", data_alignment=-1, offset_factor=0, buffer_type="", diff --git a/src/ir/type.cc b/src/ir/type.cc index 567e31d9c2a6..fe8e00329bbc 100644 --- a/src/ir/type.cc +++ b/src/ir/type.cc @@ -53,7 +53,7 @@ PointerType::PointerType(Type element_type, String storage_scope) { TVM_REGISTER_NODE_TYPE(PointerTypeNode); TVM_REGISTER_GLOBAL("ir.PointerType") - .set_body_typed([](Type element_type, String storage_scope = "global") { + .set_body_typed([](Type element_type, String storage_scope = "") { return PointerType(element_type, storage_scope); }); diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc index e9caeabcabd0..b80f76f5acbf 100644 --- a/src/te/schedule/schedule_postproc_to_primfunc.cc +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -49,7 +49,7 @@ namespace tvm { namespace te { // create a buffer for tensor. -Buffer CreateBufferFor(const Tensor& tensor, String storage_scope = "global") { +Buffer CreateBufferFor(const Tensor& tensor, String storage_scope = "") { std::string name = tensor->op->name; if (tensor->op->num_outputs() != 1) { name += ".v" + std::to_string(tensor->value_index); @@ -122,11 +122,11 @@ class TensorToBufferMapper : public StmtExprMutator { } private: - Buffer GetOrAllocBuffer(const Tensor& tensor, String storage_scope = "global") { + Buffer GetOrAllocBuffer(const Tensor& tensor, String storage_scope = "") { return GetBuffer(tensor, storage_scope, true); } - Buffer GetBuffer(const Tensor& tensor, String storage_scope = "global", + Buffer GetBuffer(const Tensor& tensor, String storage_scope = "", bool allow_alloc = false) { auto it = buffer_map_.find(tensor); if (it != buffer_map_.end()) return it->second; diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 704afed689cd..49da7c7f5630 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -48,7 +48,6 @@ Array SimplifyArray(arith::Analyzer* ana, Array array) { Buffer decl_buffer(Array shape, DataType dtype, String name, String storage_scope, Span span) { DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype); - if (storage_scope == "") storage_scope = "global"; return Buffer(Var(name, PointerType(PrimType(storage_dtype), storage_scope), span), dtype, shape, Array(), PrimExpr(), name, "", 0, 0, kDefault, span); } diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 7af36d164a56..eca3bba83583 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -157,12 +157,13 @@ class StorageFlattener : public StmtExprMutator { // deduce current storage scope. 
StorageScope skey; std::string strkey = GetStorageScope(op->buffer->data); - if (curr_thread_scope_.size() != 0 && (strkey == "" || strkey == "global")) { - skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank); + if (strkey.length() == 0) { + if (curr_thread_scope_.size() != 0) { + skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank); + } } else { skey = StorageScope::Create(strkey); } - // use small alignment for small arrays auto dtype = op->buffer->dtype; int32_t const_size = AllocateNode::constant_allocation_size(shape); From 1c8b7796517dacc7c8deb4606cc0c515246c515d Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Fri, 2 Jul 2021 14:56:42 +0900 Subject: [PATCH 64/90] format --- include/tvm/tir/buffer.h | 3 +-- src/te/schedule/schedule_postproc_to_primfunc.cc | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index fd6718a44e4b..f01158967bdd 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -197,8 +197,7 @@ class Buffer : public ObjectRef { * \sa Buffer for complete constructor. */ TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Float(32), - String name = "buffer", String storage_scope = "", - Span span = Span()); + String name = "buffer", String storage_scope = "", Span span = Span()); /*! * \brief Return the storage scope associated with a buffer variable. diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc index b80f76f5acbf..8e6cc131b76e 100644 --- a/src/te/schedule/schedule_postproc_to_primfunc.cc +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -126,8 +126,7 @@ class TensorToBufferMapper : public StmtExprMutator { return GetBuffer(tensor, storage_scope, true); } - Buffer GetBuffer(const Tensor& tensor, String storage_scope = "", - bool allow_alloc = false) { + Buffer GetBuffer(const Tensor& tensor, String storage_scope = "", bool allow_alloc = false) { auto it = buffer_map_.find(tensor); if (it != buffer_map_.end()) return it->second; ICHECK(allow_alloc) << "Cannot find the Realization point of tensor " << tensor; From 9f54b6225a97e08c9ebd7b772e5c05c94475c3cd Mon Sep 17 00:00:00 2001 From: masa Date: Sun, 4 Jul 2021 10:58:25 +0900 Subject: [PATCH 65/90] fix volatile access to shared mem in lower all reduce --- src/tir/transforms/lower_thread_allreduce.cc | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 8cb7db902db6..68bf24abb847 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -52,9 +52,17 @@ class RemapStorageScope final : public StmtExprMutator { Stmt VisitStmt_(const AllocateNode* op) final { auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); - auto stmt = Allocate(remapped, op->dtype, op->extents, op->condition, - StmtExprMutator::VisitStmt(op->body)); - return AttrStmt(remapped, attr::storage_scope, StringImm(GetStorageScope(remapped)), stmt); + auto new_scope = GetStorageScope(remapped); + if (new_scope != GetStorageScope(op->buffer_var)) { + Stmt body = StmtExprMutator::VisitStmt(op->body); + if (new_scope == "shared") { + // use volatile access to shared buffer. 
+        body = AttrStmt(remapped, attr::volatile_scope, 1, body);
+      }
+      body = Allocate(remapped, op->dtype, op->extents, op->condition, body);
+      return AttrStmt(remapped, attr::storage_scope, StringImm(new_scope), body);
+    }
+    return StmtExprMutator::VisitStmt_(op);
   }

   Stmt VisitStmt_(const StoreNode* op) final {
@@ -124,9 +132,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
       stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body);
       new_var_remap_[repl->buffer_var.get()] = UpdateStorageScope(repl->buffer_var, "local");
     } else {
-      // use volatile access to shared buffer.
-      stmt = AttrStmt(repl->buffer_var, attr::volatile_scope, 1, op->body);
-      stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, stmt);
+      stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body);
       new_var_remap_[repl->buffer_var.get()] = UpdateStorageScope(repl->buffer_var, "shared");
     }
     return stmt;

From 034fb72769e78f5bb8e281fe88a14716387e6e2c Mon Sep 17 00:00:00 2001
From: masa
Date: Sun, 4 Jul 2021 11:48:31 +0900
Subject: [PATCH 66/90] fixed gpu cooperative load/store test

---
 src/tir/transforms/storage_rewrite.cc | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc
index e99b0c0e1086..a18bc84604aa 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -922,7 +922,8 @@ class VectorAllocRewriter : public StmtExprMutator {
                      extents[extents.size() - 1] / make_const(extents[0].dtype(), factor));
       // create a new buffer var
       DataType new_dtype = tvec[0];
-      Var new_buffer_var(op->buffer_var->name_hint, PointerType(PrimType(new_dtype)));
+      Var new_buffer_var(op->buffer_var->name_hint,
+                         PointerType(PrimType(new_dtype), GetStorageScope(op->buffer_var)));
       // update the remap req.
       var_remap_.Set(op->buffer_var, new_buffer_var);
       return Allocate(new_buffer_var, new_dtype, extents, op->condition, op->body);

From 60815194dbe464357c3086ade3c49848c92a0f4d Mon Sep 17 00:00:00 2001
From: masa
Date: Tue, 6 Jul 2021 06:58:44 +0900
Subject: [PATCH 67/90] pass storage scope to PointerType in tvm script parser

This reverts commit 99cfb9d18781dcfdea169d920450f9063ab18b6b.
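The parser change below is one line, but it is what lets TVMScript allocations carry their scope on the variable itself. Roughly, the var the parser now builds looks like this (a sketch mirroring the diff; the names are illustrative):

    import tvm

    dtype, scope = "float32", "shared"
    buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), scope)
    buffer_var = tvm.tir.Var("A_shared", buffer_ptr_type)
    assert buffer_var.type_annotation.storage_scope == "shared"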
--- Jenkinsfile | 1 + python/tvm/script/scope_handler.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/Jenkinsfile b/Jenkinsfile index f26b148085fb..815c07ad8806 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -282,6 +282,7 @@ stage('Unit Test') { timeout(time: max_time, unit: 'MINUTES') { sh "${docker_run} ${ci_arm} ./tests/scripts/task_ci_setup.sh" sh "${docker_run} ${ci_arm} ./tests/scripts/task_python_unittest.sh" + sh "${docker_run} ${ci_arm} ./tests/scripts/task_python_arm_compute_library.sh" junit "build/pytest-results/*.xml" // sh "${docker_run} ${ci_arm} ./tests/scripts/task_python_integration.sh" } diff --git a/python/tvm/script/scope_handler.py b/python/tvm/script/scope_handler.py index a23401d926e9..d07209485bd4 100644 --- a/python/tvm/script/scope_handler.py +++ b/python/tvm/script/scope_handler.py @@ -140,7 +140,7 @@ def enter_scope( def setup_buffer_var(extents, dtype, scope, condition=True, span: Span = None): """Setup buffer var for a given type.""" - buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype)) + buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), scope) self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span) setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span)) From 2d92f6d4274d0ae7ef0f8f798a6ec480c2a5cc58 Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 6 Jul 2021 15:16:17 +0900 Subject: [PATCH 68/90] fixed tvmscript roundtrip test --- python/tvm/script/special_stmt.py | 16 ++++++++++++++++ src/printer/tvmscript_printer.cc | 14 ++++++++++++-- .../python/unittest/test_tvmscript_roundtrip.py | 4 ++-- 3 files changed, 30 insertions(+), 4 deletions(-) diff --git a/python/tvm/script/special_stmt.py b/python/tvm/script/special_stmt.py index 7eb938c58f96..befa37e19252 100644 --- a/python/tvm/script/special_stmt.py +++ b/python/tvm/script/special_stmt.py @@ -491,6 +491,22 @@ def var(dtype, span): super().__init__(var, def_symbol=True) +@register +class BufferVarDef(SpecialStmt): + """Special function for defining a Var""" + + def __init__(self): + def buffer_var(dtype, storage_scope, span): + assert isinstance( + self.node, ast.Assign + ), f"VarDef expected ast.Assign but got {type(self.node)}" + ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope) + v = te.var(self.node.lhs.id.name, ptr_type, span=span) + self.context.update_symbol(v.name, v, self.node) + + super().__init__(buffer_var, def_symbol=True) + + @register class EnvThread(SpecialStmt): """Bind a var to thread env""" diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 4bbe17064c87..e855712617ca 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -1013,8 +1014,17 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { return memo_var_[GetRef(a)].str() < memo_var_[GetRef(b)].str(); }); for (const auto& var : vars) { - header_var << Doc::NewLine() << Print(GetRef(var)) << " = tir.var("; - header_var << PrintDType(var->dtype) << ")"; + auto type = GetRef(var)->type_annotation; + if (auto* ptr_type = type.as()) { + auto* prim_type = ptr_type->element_type.as(); + ICHECK(prim_type); + header_var << Doc::NewLine() << Print(GetRef(var)) << " = tir.buffer_var("; + header_var << PrintDType(prim_type->dtype) << ", " + << Doc::StrLiteral(ptr_type->storage_scope) << ")"; + } else { + header_var << Doc::NewLine() << Print(GetRef(var)) << " = tir.var("; + header_var << PrintDType(var->dtype) << ")"; + 
} } } doc << Doc::Indent(4, header_attr << header_var << header_buf << body); diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 164949552859..6c0e228e8e4c 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -277,8 +277,8 @@ def mmult( } ) # var definition - C_global = tir.var("handle") - packedB = tir.var("handle") + C_global = tir.buffer_var("float32", "global") + packedB = tir.buffer_var("float32", "global") # body assert num_args == 3, "mmult: num_args should be 3" arg0: ty.handle = tir.tvm_struct_get(args, 0, 12, dtype="handle") From 35fb917d4c181ed6e8a49a41121ae7ea7fc2ab33 Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 6 Jul 2021 15:37:19 +0900 Subject: [PATCH 69/90] fixed tir flatten buffer test --- python/tvm/script/special_stmt.py | 4 ++-- .../python/unittest/test_tir_transform_flatten_buffer.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tvm/script/special_stmt.py b/python/tvm/script/special_stmt.py index befa37e19252..6dbbb7354b26 100644 --- a/python/tvm/script/special_stmt.py +++ b/python/tvm/script/special_stmt.py @@ -493,13 +493,13 @@ def var(dtype, span): @register class BufferVarDef(SpecialStmt): - """Special function for defining a Var""" + """Special function for defining a Var of pointer type""" def __init__(self): def buffer_var(dtype, storage_scope, span): assert isinstance( self.node, ast.Assign - ), f"VarDef expected ast.Assign but got {type(self.node)}" + ), f"BufferVarDef expected ast.Assign but got {type(self.node)}" ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope) v = te.var(self.node.lhs.id.name, ptr_type, span=span) self.context.update_symbol(v.name, v, self.node) diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index c997748649cd..6929a329ac0f 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -35,7 +35,7 @@ def compacted_elementwise_func(a: ty.handle, c: ty.handle) -> None: with tir.block([]): tir.reads(A[i, 0:16]) tir.writes(C[i, 0:16]) - B = tir.alloc_buffer([1, 16], "float32") + B = tir.alloc_buffer([1, 16], "float32", scope="global") for j in range(0, 16): with tir.block() as []: tir.reads(A[i, j]) @@ -111,7 +111,7 @@ def compacted_symbolic_func(a: ty.handle, c: ty.handle, n: ty.int32, m: ty.int32 with tir.block([]): tir.reads(A[i, m]) tir.writes(C[i, m]) - B = tir.alloc_buffer((m,), "float32") + B = tir.alloc_buffer((m,), "float32", scope="global") for j in range(0, m): with tir.block([]) as []: tir.reads(A[i, j]) @@ -190,8 +190,8 @@ def compacted_multi_alloc_func(a: ty.handle, d: ty.handle) -> None: with tir.block([]) as []: tir.reads(A[i]) tir.writes(D[i]) - B = tir.alloc_buffer((32,)) - C = tir.alloc_buffer((32,)) + B = tir.alloc_buffer((32,), scope="global") + C = tir.alloc_buffer((32,), scope="global") B[i] = A[i] + 1.0 C[i] = A[i] + B[i] D[i] = C[i] * 2.0 From 46647e93593597d434d920928c01c422ef3f3975 Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 6 Jul 2021 16:00:05 +0900 Subject: [PATCH 70/90] fixed test_tir_transform_hoist_if.py --- tests/python/unittest/test_tir_transform_hoist_if.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_tir_transform_hoist_if.py b/tests/python/unittest/test_tir_transform_hoist_if.py index 
252a187dbdc5..b111e2be75c7 100644 --- a/tests/python/unittest/test_tir_transform_hoist_if.py +++ b/tests/python/unittest/test_tir_transform_hoist_if.py @@ -636,7 +636,7 @@ def test_hoisting_block_scope_4(): def test_hoisting_block_scope_5(): ib = tvm.tir.ir_builder.create() - data = ib.pointer("float32", name="data") + data = ib.pointer("float32", name="data", scope="global") l = te.var("l") m = te.var("m") n = te.var("n") From db8eb9a88f420fb12dc2ca6f2e0cc42e0a3e1231 Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 6 Jul 2021 16:00:28 +0900 Subject: [PATCH 71/90] use storage scope global by default in aot_executor_codegen.cc --- src/relay/backend/aot_executor_codegen.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 9b495adbdea8..9b613bbff99f 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -722,7 +722,8 @@ class AOTExecutorCodegen : public ExprVisitor { // Define the storage allocator ids for (auto kv : storage_device_map_) { for (auto sid : kv.second->storage_ids) { - te::Var buffer_var(MakeString("sid_", sid), PointerType(PrimType(DataType::Int(8)))); + te::Var buffer_var(MakeString("sid_", sid), + PointerType(PrimType(DataType::Int(8)), "global")); sids_table_[sid] = buffer_var; } } From 66c61ae5b9956c38738d2d16c8fbd2f8df69a0d1 Mon Sep 17 00:00:00 2001 From: masa Date: Tue, 6 Jul 2021 16:33:56 +0900 Subject: [PATCH 72/90] add missing default storage scope in create_primfunc.cc --- src/te/operation/create_primfunc.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 190892b2283f..a47556bac101 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -109,7 +109,7 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te:: } // Step 2. Declare buffer and update op2buffers - Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint()); + Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint(), "global"); info->tensor2buffers[tensor] = buffer; // Step 3. Add Buffer to root_alloc @@ -270,7 +270,8 @@ PrimFunc CreatePrimFunc(const Array& arg_list) { const te::Tensor& tensor = op.output(0); // Check op is in op list ICHECK(info.IsArg(tensor)); - const Buffer& buffer = decl_buffer(placeholder->shape, placeholder->dtype, placeholder->name); + const Buffer& buffer = + decl_buffer(placeholder->shape, placeholder->dtype, placeholder->name, "global"); info.tensor2buffers[tensor] = buffer; } else if (const auto* compute_op = op.as()) { // Case 2. 
ComputeOp (te.compute) From bd76606e547aa0f4af75556cb46ed7dcf79880fc Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 7 Jul 2021 06:17:00 +0900 Subject: [PATCH 73/90] restore StorageInfo struct in llvm backend --- src/target/llvm/codegen_amdgpu.cc | 22 +++++++-------- src/target/llvm/codegen_cpu.cc | 6 ++--- src/target/llvm/codegen_llvm.cc | 45 ++++++++++++++++--------------- src/target/llvm/codegen_llvm.h | 9 +++++-- src/target/llvm/codegen_nvptx.cc | 20 +++++++------- 5 files changed, 54 insertions(+), 48 deletions(-) diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 01d2b2f7ad4d..4b182e17ed74 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -76,13 +76,13 @@ class CodeGenAMDGPU : public CodeGenLLVM { int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; - int& alignment = alloc_storage_alignment_[op->buffer_var.get()]; - if (constant_size % 4 == 0 && alignment == 0) { - alignment = GetTempAllocaAlignment(op->dtype, constant_size); + StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; + if (constant_size % 4 == 0 && info.alignment == 0) { + info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); } // maximum necessary alignment in the AMD devices - if (alignment > 16) { - alignment = 16; + if (info.alignment > 16) { + info.alignment = 16; } auto storage_scope = runtime::StorageScope::Create(GetStorageScope(op->buffer_var)); if (storage_scope.rank == runtime::StorageRank::kLocal) { @@ -91,11 +91,11 @@ class CodeGenAMDGPU : public CodeGenLLVM { llvm::AllocaInst* alloca = WithFunctionEntry([&]() { return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); }); - if (alloca->getAlignment() < static_cast(alignment)) { + if (alloca->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 - alloca->setAlignment(llvm::Align(alignment)); + alloca->setAlignment(llvm::Align(info.alignment)); #else - alloca->setAlignment(alignment); + alloca->setAlignment(info.alignment); #endif } buf = alloca; @@ -109,11 +109,11 @@ class CodeGenAMDGPU : public CodeGenLLVM { llvm::GlobalVariable* global = new llvm::GlobalVariable( *module_, type, false, llvm::GlobalValue::PrivateLinkage, nullptr, ".shared", nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space); - if (global->getAlignment() < static_cast(alignment)) { + if (global->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 - global->setAlignment(llvm::Align(alignment)); + global->setAlignment(llvm::Align(info.alignment)); #else - global->setAlignment(alignment); + global->setAlignment(info.alignment); #endif } buf = global; diff --git a/src/target/llvm/codegen_cpu.cc b/src/target/llvm/codegen_cpu.cc index b9761355b208..ab96d6e69d14 100644 --- a/src/target/llvm/codegen_cpu.cc +++ b/src/target/llvm/codegen_cpu.cc @@ -463,9 +463,9 @@ void CodeGenCPU::CreateComputeScope(const AttrStmtNode* op) { } // Add alignment attribute if needed. 
#if TVM_LLVM_VERSION >= 50 - auto f = alloc_storage_alignment_.find(var.get()); - if (f != alloc_storage_alignment_.end()) { - unsigned align = f->second; + auto f = alloc_storage_info_.find(var.get()); + if (f != alloc_storage_info_.end()) { + unsigned align = f->second.alignment; if (align > 1) { auto attr = llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, align); fcompute->addParamAttr(idx, attr); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 545c94dddae3..9701a9f9ebb0 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -104,7 +104,7 @@ void CodeGenLLVM::AddFunction(const PrimFunc& f) { this->AddFunctionInternal(f, void CodeGenLLVM::InitFuncState() { var_map_.clear(); alias_var_set_.clear(); - alloc_storage_alignment_.clear(); + alloc_storage_info_.clear(); volatile_buf_.clear(); analyzer_.reset(new arith::Analyzer()); } @@ -165,9 +165,9 @@ void CodeGenLLVM::AddFunctionInternal(const PrimFunc& f, bool ret_void) { #if TVM_LLVM_VERSION >= 50 for (size_t i = 0; i < f->params.size(); ++i) { const Var& var = f->params[i]; - auto f = alloc_storage_alignment_.find(var.get()); - if (f != alloc_storage_alignment_.end()) { - unsigned align = f->second; + auto f = alloc_storage_info_.find(var.get()); + if (f != alloc_storage_info_.end()) { + unsigned align = f->second.alignment; if (align > 1) { auto attr = llvm::Attribute::get(*ctx_, llvm::Attribute::Alignment, align); function_->addParamAttr(i, attr); @@ -498,12 +498,12 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst, const VarNode* buffer, P void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExpr& index, int* p_alignment, int* p_native_bits) { int max_align_bits = t.bits(); - auto it = alloc_storage_alignment_.find(buf_var); - if (it != alloc_storage_alignment_.end()) { - const int alignment = it->second; + auto it = alloc_storage_info_.find(buf_var); + if (it != alloc_storage_info_.end()) { + const StorageInfo& info = it->second; *p_native_bits = NativeVectorBits(runtime::StorageScope::Create(GetStorageScope(GetRef(buf_var)))); - max_align_bits = alignment * 8; + max_align_bits = info.alignment * 8; } else { *p_native_bits = native_vector_bits_; } @@ -1354,25 +1354,25 @@ void CodeGenLLVM::VisitStmt_(const AllocateNode* op) { int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation"; - int& alignment = alloc_storage_alignment_[op->buffer_var.get()]; - if (constant_size % 4 == 0 && alignment == 0) { - alignment = GetTempAllocaAlignment(op->dtype, constant_size); + StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; + if (constant_size % 4 == 0 && info.alignment == 0) { + info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); } // maximum necessary alignment in the NV devices - if (alignment > 16) { - alignment = 16; + if (info.alignment > 16) { + info.alignment = 16; } llvm::AllocaInst* alloca = WithFunctionEntry([&]() { return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); }); - if (alloca->getAlignment() < static_cast(alignment)) { + if (alloca->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 - alloca->setAlignment(llvm::Align(alignment)); + alloca->setAlignment(llvm::Align(info.alignment)); #else - alloca->setAlignment(alignment); + alloca->setAlignment(info.alignment); #endif } - alignment = alloca->getAlignment(); + info.alignment = alloca->getAlignment(); buf = 
alloca; buf = builder_->CreatePointerCast( @@ -1394,10 +1394,10 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) { } else if (op->attr_key == tir::attr::storage_alignment) { const VarNode* v = op->node.as(); ICHECK(v); - alloc_storage_alignment_[v] = static_cast(op->value.as()->value); - if (var_map_.count(v) && alloc_storage_alignment_[v] > 1) { + alloc_storage_info_[v].alignment = static_cast(op->value.as()->value); + if (var_map_.count(v) && alloc_storage_info_[v].alignment > 1) { builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v), - alloc_storage_alignment_[v]); + alloc_storage_info_[v].alignment); } } else if (op->attr_key == tir::attr::volatile_scope) { const VarNode* v = op->node.as(); @@ -1422,8 +1422,9 @@ void CodeGenLLVM::VisitStmt_(const LetStmtNode* op) { } var_map_[v] = MakeValue(op->value); analyzer_->Bind(op->var, op->value); - if (alloc_storage_alignment_.count(v) && alloc_storage_alignment_[v] > 1) { - builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v), alloc_storage_alignment_[v]); + if (alloc_storage_info_.count(v) && alloc_storage_info_[v].alignment > 1) { + builder_->CreateAlignmentAssumption(*data_layout_, GetVarValue(v), + alloc_storage_info_[v].alignment); } this->VisitStmt(op->body); } diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index fb13ce42f897..810e59be7214 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -161,6 +161,11 @@ class CodeGenLLVM : public ExprFunctor, void VisitStmt_(const EvaluateNode* op) override; protected: + /*! \brief The storage information */ + struct StorageInfo { + /*! \brief The alignment of allocation */ + int alignment{0}; + }; /*! * \brief Execute falloca at the beginning of the * currrent function and obtain its return value. @@ -320,8 +325,8 @@ class CodeGenLLVM : public ExprFunctor, std::vector > link_modules_; /*! \brief native vector bits of current targetx*/ int native_vector_bits_{0}; - /*! \brief the alignment of allocation */ - std::unordered_map alloc_storage_alignment_; + /*! \brief the storage scope of allocation */ + std::unordered_map alloc_storage_info_; // The definition of local variable. 
std::unordered_map var_map_; // global strings diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index e8ae088ece32..18faf34143f0 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -51,13 +51,13 @@ class CodeGenNVPTX : public CodeGenLLVM { int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; - int& alignment = alloc_storage_alignment_[op->buffer_var.get()]; - if (constant_size % 4 == 0 && alignment == 0) { - alignment = GetTempAllocaAlignment(op->dtype, constant_size); + StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; + if (constant_size % 4 == 0 && info.alignment == 0) { + info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); } // maximum necessary alignment in the NV devices - if (alignment > 16) { - alignment = 16; + if (info.alignment > 16) { + info.alignment = 16; } auto storage_scope = runtime::StorageScope::Create(GetStorageScope(op->buffer_var)); if (storage_scope.rank == runtime::StorageRank::kLocal) { @@ -66,11 +66,11 @@ class CodeGenNVPTX : public CodeGenLLVM { llvm::AllocaInst* alloca = WithFunctionEntry([&]() { return builder_->CreateAlloca(DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); }); - if (alloca->getAlignment() < static_cast(alignment)) { + if (alloca->getAlignment() < static_cast(info.alignment)) { #if TVM_LLVM_VERSION >= 100 - alloca->setAlignment(llvm::Align(alignment)); + alloca->setAlignment(llvm::Align(info.alignment)); #else - alloca->setAlignment(alignment); + alloca->setAlignment(info.alignment); #endif } buf = alloca; @@ -85,9 +85,9 @@ class CodeGenNVPTX : public CodeGenLLVM { *module_, type, false, llvm::GlobalValue::PrivateLinkage, nullptr, ".shared", nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space); #if TVM_LLVM_VERSION >= 100 - global->setAlignment(llvm::Align(alignment)); + global->setAlignment(llvm::Align(info.alignment)); #else - global->setAlignment(alignment); + global->setAlignment(info.alignment); #endif buf = global; } From 4cae209bf01384f4c1c70f4d8695e2da1072e161 Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 7 Jul 2021 06:43:04 +0900 Subject: [PATCH 74/90] UpdateStorageScope -> WithStorageScope --- include/tvm/tir/buffer.h | 2 +- src/tir/ir/buffer.cc | 2 +- src/tir/transforms/lower_thread_allreduce.cc | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index f01158967bdd..d1e1c5b5103a 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -205,7 +205,7 @@ TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Flo * \return A string representing the storage scope of this buffer variable. */ TVM_DLL String GetStorageScope(Var buffer_var); -TVM_DLL Var UpdateStorageScope(Var buffer_var, String storage_scope); +TVM_DLL Var WithStorageScope(Var buffer_var, String storage_scope); /*! * \brief Base node for data producers. 
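The rename follows the With* convention for helpers that return an updated copy; the behavior, shown in the implementation just below, stays the same. In Python terms the helper amounts to the following (an illustrative equivalent, not an exposed API):

    import tvm

    def with_storage_scope(buffer_var, scope):
        # Rebuild the var with the same element type but a new scope.
        ptr_ty = buffer_var.type_annotation
        new_ty = tvm.ir.PointerType(ptr_ty.element_type, scope)
        return tvm.tir.Var(buffer_var.name, new_ty)

    v = tvm.tir.Var("buf", tvm.ir.PointerType(tvm.ir.PrimType("float32"), "shared"))
    v_local = with_storage_scope(v, "local")
    assert v_local.type_annotation.storage_scope == "local"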
diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 49da7c7f5630..670e503d6114 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -59,7 +59,7 @@ String GetStorageScope(Var buffer_var) { return ptr_type->storage_scope; } -Var UpdateStorageScope(Var buffer_var, String storage_scope) { +Var WithStorageScope(Var buffer_var, String storage_scope) { auto* ptr_type = buffer_var->type_annotation.as(); ICHECK(ptr_type) << "The provided variable is not of pointer type"; return Var(buffer_var->name_hint, PointerType(ptr_type->element_type, storage_scope), diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 68bf24abb847..17f07012c6f3 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -130,10 +130,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { const AllocateNode* repl = it->second.as(); if (warp_allocs_.count(repl)) { stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body); - new_var_remap_[repl->buffer_var.get()] = UpdateStorageScope(repl->buffer_var, "local"); + new_var_remap_[repl->buffer_var.get()] = WithStorageScope(repl->buffer_var, "local"); } else { stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body); - new_var_remap_[repl->buffer_var.get()] = UpdateStorageScope(repl->buffer_var, "shared"); + new_var_remap_[repl->buffer_var.get()] = WithStorageScope(repl->buffer_var, "shared"); } return stmt; } else { @@ -410,7 +410,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { const AllocateNode* repl = var.as(); if (repl) { body = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body); - new_var_remap_[repl->buffer_var.get()] = UpdateStorageScope(repl->buffer_var, "local"); + new_var_remap_[repl->buffer_var.get()] = WithStorageScope(repl->buffer_var, "local"); } } From 01b94ccf8b21c44c967061a36a13621278aebdcf Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 7 Jul 2021 07:07:06 +0900 Subject: [PATCH 75/90] fixed lower warp memory test --- src/tir/transforms/lower_thread_allreduce.cc | 4 +++ src/tir/transforms/lower_warp_memory.cc | 38 +++++++++++++++++--- 2 files changed, 38 insertions(+), 4 deletions(-) diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 17f07012c6f3..9488d1e99fcb 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -37,6 +37,8 @@ namespace tvm { namespace tir { +namespace { + class RemapStorageScope final : public StmtExprMutator { public: explicit RemapStorageScope(const std::unordered_map& new_var_remap) @@ -81,6 +83,8 @@ class RemapStorageScope final : public StmtExprMutator { std::unordered_map new_var_remap_; }; +} // namespace + class ThreadAllreduceBuilder final : public StmtExprMutator { public: explicit ThreadAllreduceBuilder(const TargetNode* target) diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index b95681a936ca..8f7382a21153 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -44,6 +44,34 @@ namespace tvm { namespace tir { +namespace { + +class RemapStorageScope final : public StmtExprMutator { + public: + explicit RemapStorageScope(const std::unordered_map& new_var_remap) + : new_var_remap_(new_var_remap) {} + + Stmt VisitStmt_(const AttrStmtNode* op) { + using runtime::StorageScope; + if 
(op->attr_key == attr::storage_scope) { + const VarNode* buf = op->node.as(); + auto it = new_var_remap_.find(buf); + if (it != new_var_remap_.end()) { + auto remapped = it->second; + auto new_scope = GetStorageScope(remapped); + return AttrStmt(remapped, attr::storage_scope, StringImm(new_scope), + StmtMutator::VisitStmt(op->body)); + } + } + return StmtMutator::VisitStmt_(op); + } + + private: + std::unordered_map new_var_remap_; +}; + +} // namespace + // Rewrite Rule // // There is no special warp memory in most GPUs. @@ -356,6 +384,8 @@ class WarpMemoryRewriter : private StmtMutator { return stmt; } + std::unordered_map new_var_remap_; + private: Stmt VisitStmt_(const AllocateNode* op) { auto ret = StmtMutator::VisitStmt_(op); @@ -374,9 +404,7 @@ class WarpMemoryRewriter : private StmtMutator { StorageScope scope = StorageScope::Create(op->value.as()->value); if (scope.rank == runtime::StorageRank::kWarp) { warp_buffer_.insert(buf); - Stmt ret = StmtMutator::VisitStmt_(op); - op = ret.as(); - return AttrStmt(op->node, op->attr_key, StringImm("local"), op->body); + new_var_remap_[buf] = WithStorageScope(GetRef(buf), "local"); } } return StmtMutator::VisitStmt_(op); @@ -397,7 +425,9 @@ Pass LowerWarpMemory() { auto target = f->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; int warp_size = target.value()->GetAttr("thread_warp_size", 1).value(); - n->body = WarpMemoryRewriter(warp_size).Rewrite(std::move(n->body)); + WarpMemoryRewriter warp_memory_rewriter(warp_size); + auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body)); + n->body = RemapStorageScope(warp_memory_rewriter.new_var_remap_)(stmt); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {}); From cbaa6e7c5dd49bfdf5fc321f29ded70e966d0b65 Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 7 Jul 2021 07:33:29 +0900 Subject: [PATCH 76/90] GetStorageScope -> GetPtrStorageScope --- include/tvm/tir/buffer.h | 2 +- src/target/llvm/codegen_amdgpu.cc | 2 +- src/target/llvm/codegen_llvm.cc | 2 +- src/target/llvm/codegen_nvptx.cc | 2 +- src/target/source/codegen_cuda.cc | 2 +- src/target/spirv/codegen_spirv.cc | 2 +- src/tir/ir/buffer.cc | 2 +- src/tir/transforms/lower_thread_allreduce.cc | 4 ++-- src/tir/transforms/lower_warp_memory.cc | 2 +- src/tir/transforms/storage_access.cc | 2 +- src/tir/transforms/storage_flatten.cc | 2 +- src/tir/transforms/storage_rewrite.cc | 4 ++-- src/tir/transforms/thread_storage_sync.cc | 2 +- 13 files changed, 15 insertions(+), 15 deletions(-) diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index d1e1c5b5103a..73416d7fad2c 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -204,7 +204,7 @@ TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Flo * \param buffer_var The input buffer variable. * \return A string representing the storage scope of this buffer variable. */ -TVM_DLL String GetStorageScope(Var buffer_var); +TVM_DLL String GetPtrStorageScope(Var buffer_var); TVM_DLL Var WithStorageScope(Var buffer_var, String storage_scope); /*! 
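GetPtrStorageScope makes it explicit that the scope is read off the variable's pointer type, with a hard check that the var is pointer-typed. A Python analogue for reference (illustrative only, not part of the API):

    import tvm

    def get_ptr_storage_scope(buffer_var):
        ptr_ty = buffer_var.type_annotation
        assert isinstance(ptr_ty, tvm.ir.PointerType), "provided variable is not of pointer type"
        return ptr_ty.storage_scope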
diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 4b182e17ed74..9aec8f4e867b 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -84,7 +84,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { if (info.alignment > 16) { info.alignment = 16; } - auto storage_scope = runtime::StorageScope::Create(GetStorageScope(op->buffer_var)); + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); if (storage_scope.rank == runtime::StorageRank::kLocal) { // const int local_address_space = 5; // TODO(tqchen): for higher version of LLVM, local address space can be set. diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 9701a9f9ebb0..bdae93b82aff 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -502,7 +502,7 @@ void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExp if (it != alloc_storage_info_.end()) { const StorageInfo& info = it->second; *p_native_bits = - NativeVectorBits(runtime::StorageScope::Create(GetStorageScope(GetRef(buf_var)))); + NativeVectorBits(runtime::StorageScope::Create(GetPtrStorageScope(GetRef(buf_var)))); max_align_bits = info.alignment * 8; } else { *p_native_bits = native_vector_bits_; diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index 18faf34143f0..43ea0e6b7ae9 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -59,7 +59,7 @@ class CodeGenNVPTX : public CodeGenLLVM { if (info.alignment > 16) { info.alignment = 16; } - auto storage_scope = runtime::StorageScope::Create(GetStorageScope(op->buffer_var)); + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); if (storage_scope.rank == runtime::StorageRank::kLocal) { // const int local_address_space = 5; // TODO(tqchen): for higher version of LLVM, local address space can be set. 
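With the side table gone, every GPU backend recovers the scope from the allocation's buffer var. From Python, an ir_builder allocation exercises the same path end to end (a small sketch; it assumes the updated ir_builder threads scope into the var's pointer type, as other commits in this series arrange):

    import tvm

    ib = tvm.tir.ir_builder.create()
    a = ib.allocate("float32", 64, name="a_shared", scope="shared")
    a[0] = 0.0
    stmt = ib.get()
    # The Allocate's buffer var is typed Pointer(float32, "shared"),
    # which is what CodeGenNVPTX/CodeGenAMDGPU branch on above.
    print(stmt)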
diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 66b401c731e3..d7dcbec7ebe3 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -705,7 +705,7 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { this->PrintIndent(); int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; - std::string scope = GetStorageScope(op->buffer_var); + std::string scope = GetPtrStorageScope(op->buffer_var); if (scope.find("wmma.") == 0) { if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) || diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 7c9dfcaf95e0..cc20b985e3c6 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -647,7 +647,7 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) { ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; spirv::Value buf; - auto storage_scope = runtime::StorageScope::Create(GetStorageScope(op->buffer_var)); + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); spirv::SType etype = builder_->GetSType(op->dtype); if (storage_scope.rank == runtime::StorageRank::kLocal) { buf = diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 670e503d6114..abf75924b582 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -52,7 +52,7 @@ Buffer decl_buffer(Array shape, DataType dtype, String name, String st Array(), PrimExpr(), name, "", 0, 0, kDefault, span); } -String GetStorageScope(Var buffer_var) { +String GetPtrStorageScope(Var buffer_var) { auto type = buffer_var->type_annotation; const auto* ptr_type = type.as(); ICHECK(ptr_type) << "The provided variable is not of pointer type"; diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 9488d1e99fcb..2749e148a78d 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -54,8 +54,8 @@ class RemapStorageScope final : public StmtExprMutator { Stmt VisitStmt_(const AllocateNode* op) final { auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); - auto new_scope = GetStorageScope(remapped); - if (new_scope != GetStorageScope(op->buffer_var)) { + auto new_scope = GetPtrStorageScope(remapped); + if (new_scope != GetPtrStorageScope(op->buffer_var)) { Stmt body = StmtExprMutator::VisitStmt(op->body); if (new_scope == "shared") { // use volatile access to shared buffer. 
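For reference, the shape of the IR that RemapStorageScope produces when a reduction buffer moves to shared memory: the storage_scope annotation on the outside, the volatile_scope annotation guarding the body. A hand-built sketch of the same nesting (illustrative values):

    import tvm
    from tvm import tir

    var = tir.Var("red_buf", tvm.ir.PointerType(tvm.ir.PrimType("float32"), "shared"))
    body = tir.Evaluate(0)
    body = tir.AttrStmt(var, "volatile_scope", 1, body)
    alloc = tir.Allocate(var, "float32", [32], tir.IntImm("uint1", 1), body)
    stmt = tir.AttrStmt(var, "storage_scope", tir.StringImm("shared"), alloc)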
diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 8f7382a21153..49bc317a1e9c 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -58,7 +58,7 @@ class RemapStorageScope final : public StmtExprMutator { auto it = new_var_remap_.find(buf); if (it != new_var_remap_.end()) { auto remapped = it->second; - auto new_scope = GetStorageScope(remapped); + auto new_scope = GetPtrStorageScope(remapped); return AttrStmt(remapped, attr::storage_scope, StringImm(new_scope), StmtMutator::VisitStmt(op->body)); } diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc index 952758c4e5a7..9dae0006facd 100644 --- a/src/tir/transforms/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -242,7 +242,7 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { StorageScope StorageAccessVisitor::GetScope(Var buffer_var) const { if (buffer_var->type_annotation.as()) { - return StorageScope::Create(GetStorageScope(buffer_var)); + return StorageScope::Create(GetPtrStorageScope(buffer_var)); } return StorageScope(); // global by default } diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index eca3bba83583..3eccf300639a 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -156,7 +156,7 @@ class StorageFlattener : public StmtExprMutator { } // deduce current storage scope. StorageScope skey; - std::string strkey = GetStorageScope(op->buffer->data); + std::string strkey = GetPtrStorageScope(op->buffer->data); if (strkey.length() == 0) { if (curr_thread_scope_.size() != 0) { skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank); diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index a18bc84604aa..613d02614b39 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -706,7 +706,7 @@ class StoragePlanRewriter : public StmtExprMutator { for (const VarNode* var : it->second.gen) { ICHECK(alloc_info.count(var)); const AllocateNode* alloc = alloc_info.at(var).alloc; - auto storage_scope = StorageScope::Create(GetStorageScope(GetRef(var))); + auto storage_scope = StorageScope::Create(GetPtrStorageScope(GetRef(var))); StorageEntry* dst_entry = nullptr; // inplace detection if (detect_inplace) { @@ -923,7 +923,7 @@ class VectorAllocRewriter : public StmtExprMutator { // create a new buffer var DataType new_dtype = tvec[0]; Var new_buffer_var(op->buffer_var->name_hint, - PointerType(PrimType(new_dtype), GetStorageScope(op->buffer_var))); + PointerType(PrimType(new_dtype), GetPtrStorageScope(op->buffer_var))); // update the remap req. var_remap_.Set(op->buffer_var, new_buffer_var); return Allocate(new_buffer_var, new_dtype, extents, op->condition, op->body); diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index 896224c0e956..ba033f7e97e5 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -286,7 +286,7 @@ class ThreadSyncInserter : public StmtExprMutator { // Get current storage scope. StorageScope GetScope(Var buffer_var) const { - return StorageScope::Create(GetStorageScope(buffer_var)); + return StorageScope::Create(GetPtrStorageScope(buffer_var)); } // private functions. 
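Both StorageAccessVisitor::GetScope and ThreadSyncInserter::GetScope now reduce to the same rule: read the scope off the pointer type, and treat a missing or empty scope as plain global. Spelled out in Python (an illustrative mirror of the C++ helpers):

    import tvm

    def get_scope_or_global(buffer_var):
        ty = buffer_var.type_annotation
        scope = ty.storage_scope if isinstance(ty, tvm.ir.PointerType) else ""
        return scope if scope else "global"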
From 20b930b9cfc6f0a0633d0c1724295b1b7efc0670 Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 7 Jul 2021 07:34:18 +0900 Subject: [PATCH 77/90] Enable storage scope invariant check in AttrStmt constructor --- src/tir/ir/stmt.cc | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 18a5c2691005..1946faddb6bb 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -62,12 +62,15 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // AttrStmt AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span) { // TODO(masahi): Enable this invariant check - // if (attr_key == attr::storage_scope) { - // const VarNode* buf = node.as(); - // ICHECK(buf); - // ICHECK(value.as()->value == GetStorageScope(GetRef(buf))) - // << value.as()->value << ", " << GetStorageScope(GetRef(buf)); - // } + if (attr_key == attr::storage_scope) { + const VarNode* buf = node.as(); + ICHECK(buf); + auto attr_scope = value.as()->value; + auto buffer_scope = GetPtrStorageScope(GetRef(buf)); + ICHECK(attr_scope == buffer_scope) + << "Storage scopes attached to AttrStmt and buffer var are different. " << attr_scope + << ", " << buffer_scope; + } auto n = make_object(); n->node = node; n->attr_key = std::move(attr_key); From a247d946e2d4b80ece991a87ca1621ed31350d68 Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 7 Jul 2021 08:02:48 +0900 Subject: [PATCH 78/90] remove GetPtrStorageScope and WithStorageScope from public header --- include/tvm/tir/buffer.h | 8 +------- src/target/source/codegen_c.h | 1 + src/target/spirv/codegen_spirv.cc | 2 +- src/tir/ir/buffer.cc | 14 -------------- src/tir/ir/stmt.cc | 8 ++++---- src/tir/transforms/ir_utils.cc | 6 ++++++ src/tir/transforms/ir_utils.h | 6 ++++++ src/tir/transforms/lower_thread_allreduce.cc | 7 +++++++ src/tir/transforms/lower_warp_memory.cc | 8 ++++++++ src/tir/transforms/thread_storage_sync.cc | 1 - 10 files changed, 34 insertions(+), 27 deletions(-) diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index 73416d7fad2c..d26221d6a4ff 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -199,13 +199,7 @@ class Buffer : public ObjectRef { TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Float(32), String name = "buffer", String storage_scope = "", Span span = Span()); -/*! - * \brief Return the storage scope associated with a buffer variable. - * \param buffer_var The input buffer variable. - * \return A string representing the storage scope of this buffer variable. - */ -TVM_DLL String GetPtrStorageScope(Var buffer_var); -TVM_DLL Var WithStorageScope(Var buffer_var, String storage_scope); + /*! * \brief Base node for data producers. 
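Reviewer note: with the invariant check of PATCH 77 above enabled, a storage_scope AttrStmt whose string disagrees with the scope recorded on the variable's pointer type now trips an ICHECK at construction time. A hypothetical sketch of what does and does not pass (Evaluate(0) stands in for an arbitrary body):

    using namespace tvm;
    tir::Var buf("buf", PointerType(PrimType(DataType::Float(32)), "shared"));
    // OK: the attribute value agrees with the pointer-type scope.
    tir::Stmt ok = tir::AttrStmt(buf, tir::attr::storage_scope,
                                 tir::StringImm("shared"), tir::Evaluate(0));
    // Would abort: "local" does not match the "shared" scope on buf's type.
    // tir::Stmt bad = tir::AttrStmt(buf, tir::attr::storage_scope,
    //                               tir::StringImm("local"), tir::Evaluate(0));
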
diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index ae451f39f89b..834c57ac10fd 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -39,6 +39,7 @@ #include #include +#include "../../tir/transforms/ir_utils.h" #include "codegen_source_base.h" namespace tvm { diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index cc20b985e3c6..c1fa921d4507 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -23,7 +23,6 @@ */ #include "codegen_spirv.h" -#include #include #include #include @@ -33,6 +32,7 @@ #include "../../runtime/pack_args.h" #include "../../runtime/vulkan/vulkan_common.h" #include "../../runtime/vulkan/vulkan_shader.h" +#include "../../tir/transforms/ir_utils.h" namespace tvm { namespace codegen { diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index abf75924b582..e2fcf89d8966 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -52,20 +52,6 @@ Buffer decl_buffer(Array shape, DataType dtype, String name, String st Array(), PrimExpr(), name, "", 0, 0, kDefault, span); } -String GetPtrStorageScope(Var buffer_var) { - auto type = buffer_var->type_annotation; - const auto* ptr_type = type.as(); - ICHECK(ptr_type) << "The provided variable is not of pointer type"; - return ptr_type->storage_scope; -} - -Var WithStorageScope(Var buffer_var, String storage_scope) { - auto* ptr_type = buffer_var->type_annotation.as(); - ICHECK(ptr_type) << "The provided variable is not of pointer type"; - return Var(buffer_var->name_hint, PointerType(ptr_type->element_type, storage_scope), - buffer_var->span); -} - // Split the given expression w.r.t the add operator inline std::vector ExprSplitAddition(const PrimExpr& expr) { using namespace tir; diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 1946faddb6bb..e39e9b608474 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -61,15 +61,15 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // AttrStmt AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span) { - // TODO(masahi): Enable this invariant check if (attr_key == attr::storage_scope) { const VarNode* buf = node.as(); ICHECK(buf); + const auto* ptr_type = buf->type_annotation.as(); + ICHECK(ptr_type) << "The provided variable is not of pointer type"; auto attr_scope = value.as()->value; - auto buffer_scope = GetPtrStorageScope(GetRef(buf)); - ICHECK(attr_scope == buffer_scope) + ICHECK(attr_scope == ptr_type->storage_scope) << "Storage scopes attached to AttrStmt and buffer var are different. 
" << attr_scope - << ", " << buffer_scope; + << ", " << ptr_type->storage_scope; } auto n = make_object(); n->node = node; diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index cbae3f95ec68..f7ece25d3fcd 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -201,5 +201,11 @@ class IRConvertSSA final : public StmtExprMutator { Stmt ConvertSSA(Stmt stmt) { return IRConvertSSA()(std::move(stmt)); } +String GetPtrStorageScope(Var buffer_var) { + const auto* ptr_type = buffer_var->type_annotation.as(); + ICHECK(ptr_type) << "The provided variable is not of pointer type"; + return ptr_type->storage_scope; +} + } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index 906ff8a38b6c..b5a154b707af 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -191,6 +191,12 @@ inline PrimExpr StackAlloca(std::string type, size_t num) { */ Stmt ConvertSSA(Stmt stmt); +/*! + * \brief Return the storage scope associated with a buffer variable. + * \param buffer_var The input buffer variable. + * \return A string representing the storage scope of this buffer variable. + */ +String GetPtrStorageScope(Var buffer_var); } // namespace tir } // namespace tvm #endif // TVM_TIR_TRANSFORMS_IR_UTILS_H_ diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 2749e148a78d..ebd3ba88d7f7 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -39,6 +39,13 @@ namespace tir { namespace { +Var WithStorageScope(Var buffer_var, String storage_scope) { + auto* ptr_type = buffer_var->type_annotation.as(); + ICHECK(ptr_type) << "The provided variable is not of pointer type"; + return Var(buffer_var->name_hint, PointerType(ptr_type->element_type, storage_scope), + buffer_var->span); +} + class RemapStorageScope final : public StmtExprMutator { public: explicit RemapStorageScope(const std::unordered_map& new_var_remap) diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 49bc317a1e9c..3962472e57e6 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -40,12 +40,20 @@ #include "../../arith/pattern_match.h" #include "../../runtime/thread_storage_scope.h" +#include "ir_utils.h" namespace tvm { namespace tir { namespace { +Var WithStorageScope(Var buffer_var, String storage_scope) { + auto* ptr_type = buffer_var->type_annotation.as(); + ICHECK(ptr_type) << "The provided variable is not of pointer type"; + return Var(buffer_var->name_hint, PointerType(ptr_type->element_type, storage_scope), + buffer_var->span); +} + class RemapStorageScope final : public StmtExprMutator { public: explicit RemapStorageScope(const std::unordered_map& new_var_remap) diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index ba033f7e97e5..35e4563b8f58 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -22,7 +22,6 @@ */ #include #include -#include #include #include #include From 0d8c9bca078ccd24177b61a4ba0076ea28429f71 Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 7 Jul 2021 08:17:18 +0900 Subject: [PATCH 79/90] move RemapStorageScope to its own file --- include/tvm/tir/buffer.h | 2 - src/tir/transforms/lower_warp_memory.cc | 42 ++--------- .../transforms/remap_pointer_storage_scope.cc | 69 +++++++++++++++++++ 
.../transforms/remap_pointer_storage_scope.h | 44 ++++++++++++ 4 files changed, 117 insertions(+), 40 deletions(-) create mode 100644 src/tir/transforms/remap_pointer_storage_scope.cc create mode 100644 src/tir/transforms/remap_pointer_storage_scope.h diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index d26221d6a4ff..2507262c087f 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -199,8 +199,6 @@ class Buffer : public ObjectRef { TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Float(32), String name = "buffer", String storage_scope = "", Span span = Span()); - - /*! * \brief Base node for data producers. * diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index 3962472e57e6..bc43d4da5311 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -41,45 +41,11 @@ #include "../../arith/pattern_match.h" #include "../../runtime/thread_storage_scope.h" #include "ir_utils.h" +#include "remap_pointer_storage_scope.h" namespace tvm { namespace tir { -namespace { - -Var WithStorageScope(Var buffer_var, String storage_scope) { - auto* ptr_type = buffer_var->type_annotation.as(); - ICHECK(ptr_type) << "The provided variable is not of pointer type"; - return Var(buffer_var->name_hint, PointerType(ptr_type->element_type, storage_scope), - buffer_var->span); -} - -class RemapStorageScope final : public StmtExprMutator { - public: - explicit RemapStorageScope(const std::unordered_map& new_var_remap) - : new_var_remap_(new_var_remap) {} - - Stmt VisitStmt_(const AttrStmtNode* op) { - using runtime::StorageScope; - if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - auto it = new_var_remap_.find(buf); - if (it != new_var_remap_.end()) { - auto remapped = it->second; - auto new_scope = GetPtrStorageScope(remapped); - return AttrStmt(remapped, attr::storage_scope, StringImm(new_scope), - StmtMutator::VisitStmt(op->body)); - } - } - return StmtMutator::VisitStmt_(op); - } - - private: - std::unordered_map new_var_remap_; -}; - -} // namespace - // Rewrite Rule // // There is no special warp memory in most GPUs. 
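Reviewer note: the division of labor introduced here — rewriter passes record plain scope strings while a single RemapStorageScope mutator materializes the re-typed variables — can be sketched as follows (`buf_var` and `body` are placeholders for values the pass already holds):

    // Pass side: note the desired scope per buffer variable, nothing more.
    std::unordered_map<const tir::VarNode*, String> new_storage_scopes;
    new_storage_scopes[buf_var.get()] = "local";
    // Driver side: one mutator rewrites the uses of each variable. At this
    // point it handles the storage_scope AttrStmt; PATCH 80 below extends it
    // to Var, Load, Store and Allocate nodes as well.
    tir::Stmt rewritten = tir::RemapStorageScope(new_storage_scopes)(body);
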
@@ -392,7 +358,7 @@ class WarpMemoryRewriter : private StmtMutator { return stmt; } - std::unordered_map new_var_remap_; + std::unordered_map new_storage_scopes_; private: Stmt VisitStmt_(const AllocateNode* op) { @@ -412,7 +378,7 @@ class WarpMemoryRewriter : private StmtMutator { StorageScope scope = StorageScope::Create(op->value.as()->value); if (scope.rank == runtime::StorageRank::kWarp) { warp_buffer_.insert(buf); - new_var_remap_[buf] = WithStorageScope(GetRef(buf), "local"); + new_storage_scopes_[buf] = "local"; } } return StmtMutator::VisitStmt_(op); @@ -435,7 +401,7 @@ Pass LowerWarpMemory() { int warp_size = target.value()->GetAttr("thread_warp_size", 1).value(); WarpMemoryRewriter warp_memory_rewriter(warp_size); auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body)); - n->body = RemapStorageScope(warp_memory_rewriter.new_var_remap_)(stmt); + n->body = RemapStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {}); diff --git a/src/tir/transforms/remap_pointer_storage_scope.cc b/src/tir/transforms/remap_pointer_storage_scope.cc new file mode 100644 index 000000000000..70250faeca1f --- /dev/null +++ b/src/tir/transforms/remap_pointer_storage_scope.cc @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * TODO + * \file remap_pointer_storage_scope.cc + */ +#include "remap_pointer_storage_scope.h" + +#include +#include +#include +#include + +#include + +#include "../../runtime/thread_storage_scope.h" +#include "ir_utils.h" + +namespace tvm { +namespace tir { + +Var WithStorageScope(const VarNode* buffer_var, String storage_scope) { + auto* ptr_type = buffer_var->type_annotation.as(); + ICHECK(ptr_type) << "The provided variable is not of pointer type"; + return Var(buffer_var->name_hint, PointerType(ptr_type->element_type, storage_scope), + buffer_var->span); +} + +RemapStorageScope::RemapStorageScope( + const std::unordered_map& new_storage_scopes) { + for (auto kv : new_storage_scopes) { + new_var_remap_[kv.first] = WithStorageScope(kv.first, kv.second); + } +} + +Stmt RemapStorageScope::VisitStmt_(const AttrStmtNode* op) { + using runtime::StorageScope; + if (op->attr_key == attr::storage_scope) { + const VarNode* buf = op->node.as(); + auto it = new_var_remap_.find(buf); + if (it != new_var_remap_.end()) { + auto remapped = it->second; + auto new_scope = GetPtrStorageScope(remapped); + return AttrStmt(remapped, attr::storage_scope, StringImm(new_scope), + StmtMutator::VisitStmt(op->body)); + } + } + return StmtMutator::VisitStmt_(op); +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/remap_pointer_storage_scope.h b/src/tir/transforms/remap_pointer_storage_scope.h new file mode 100644 index 000000000000..051f757ddc0c --- /dev/null +++ b/src/tir/transforms/remap_pointer_storage_scope.h @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * TODO + * \file remap_pointer_storage_scope.h + */ +#include +#include +#include + +#include + +namespace tvm { +namespace tir { + +class RemapStorageScope final : public StmtExprMutator { + public: + explicit RemapStorageScope(const std::unordered_map& new_storage_scopes); + + virtual Stmt VisitStmt_(const AttrStmtNode* op); + + private: + std::unordered_map new_var_remap_; +}; + +} // namespace tir +} // namespace tvm From c2953385358ae04f7a4ee982a782c28642175e19 Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 7 Jul 2021 08:27:04 +0900 Subject: [PATCH 80/90] add more method to RemapStorageScope --- .../transforms/remap_pointer_storage_scope.cc | 37 +++++++++++++++---- .../transforms/remap_pointer_storage_scope.h | 6 ++- 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/src/tir/transforms/remap_pointer_storage_scope.cc b/src/tir/transforms/remap_pointer_storage_scope.cc index 70250faeca1f..61b22f078931 100644 --- a/src/tir/transforms/remap_pointer_storage_scope.cc +++ b/src/tir/transforms/remap_pointer_storage_scope.cc @@ -50,20 +50,43 @@ RemapStorageScope::RemapStorageScope( } } +PrimExpr RemapStorageScope::VisitExpr_(const VarNode* op) { + auto it = new_var_remap_.find(op); + if (it == new_var_remap_.end()) { + return GetRef(op); + } + return it->second; +} + +PrimExpr RemapStorageScope::VisitExpr_(const LoadNode* op) { + auto remapped = StmtExprMutator::VisitExpr(op->buffer_var); + return Load(op->dtype, Downcast(remapped), StmtExprMutator::VisitExpr(op->index), + StmtExprMutator::VisitExpr(op->predicate)); +} + Stmt RemapStorageScope::VisitStmt_(const AttrStmtNode* op) { using runtime::StorageScope; if (op->attr_key == attr::storage_scope) { const VarNode* buf = op->node.as(); - auto it = new_var_remap_.find(buf); - if (it != new_var_remap_.end()) { - auto remapped = it->second; - auto new_scope = GetPtrStorageScope(remapped); - return AttrStmt(remapped, attr::storage_scope, StringImm(new_scope), - StmtMutator::VisitStmt(op->body)); - } + auto remapped = Downcast(StmtExprMutator::VisitExpr(GetRef(buf))); + auto new_scope = GetPtrStorageScope(remapped); + return AttrStmt(remapped, attr::storage_scope, StringImm(new_scope), + StmtMutator::VisitStmt(op->body)); } return StmtMutator::VisitStmt_(op); } +Stmt RemapStorageScope::VisitStmt_(const AllocateNode* op) { + auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); + return Allocate(remapped, op->dtype, op->extents, StmtExprMutator::VisitExpr(op->condition), + StmtExprMutator::VisitStmt(op->body)); +} + +Stmt RemapStorageScope::VisitStmt_(const StoreNode* op) { + auto remapped = StmtExprMutator::VisitExpr(op->buffer_var); + return Store(Downcast(remapped), StmtExprMutator::VisitExpr(op->value), + StmtExprMutator::VisitExpr(op->index), StmtExprMutator::VisitExpr(op->predicate)); +} + } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/remap_pointer_storage_scope.h b/src/tir/transforms/remap_pointer_storage_scope.h index 051f757ddc0c..9689effd11fe 100644 --- a/src/tir/transforms/remap_pointer_storage_scope.h +++ b/src/tir/transforms/remap_pointer_storage_scope.h @@ -34,7 +34,11 @@ class RemapStorageScope final : public StmtExprMutator { public: explicit RemapStorageScope(const std::unordered_map& new_storage_scopes); - virtual Stmt VisitStmt_(const AttrStmtNode* op); + virtual PrimExpr VisitExpr_(const VarNode*); + virtual PrimExpr VisitExpr_(const LoadNode*); + virtual Stmt VisitStmt_(const AttrStmtNode*); + virtual Stmt VisitStmt_(const AllocateNode*); + virtual Stmt VisitStmt_(const 
StoreNode*); private: std::unordered_map new_var_remap_; From ebaceb98603917ace91797093eea1adc50e1947c Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 7 Jul 2021 08:34:32 +0900 Subject: [PATCH 81/90] update lower_thread_allreduce to use RemapStorageScope --- src/tir/transforms/lower_thread_allreduce.cc | 52 ++++--------------- .../transforms/remap_pointer_storage_scope.cc | 2 +- .../transforms/remap_pointer_storage_scope.h | 8 ++- 3 files changed, 17 insertions(+), 45 deletions(-) diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index ebd3ba88d7f7..d4c1c6aea594 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -33,31 +33,16 @@ #include "../../runtime/thread_storage_scope.h" #include "ir_utils.h" +#include "remap_pointer_storage_scope.h" namespace tvm { namespace tir { -namespace { - -Var WithStorageScope(Var buffer_var, String storage_scope) { - auto* ptr_type = buffer_var->type_annotation.as(); - ICHECK(ptr_type) << "The provided variable is not of pointer type"; - return Var(buffer_var->name_hint, PointerType(ptr_type->element_type, storage_scope), - buffer_var->span); -} - -class RemapStorageScope final : public StmtExprMutator { +class RemapStorageScopeAllReduce final : public RemapStorageScope { public: - explicit RemapStorageScope(const std::unordered_map& new_var_remap) - : new_var_remap_(new_var_remap) {} - - PrimExpr VisitExpr_(const VarNode* op) final { - auto it = new_var_remap_.find(op); - if (it == new_var_remap_.end()) { - return GetRef(op); - } - return it->second; - } + explicit RemapStorageScopeAllReduce( + const std::unordered_map& new_storage_scopes) + : RemapStorageScope(new_storage_scopes) {} Stmt VisitStmt_(const AllocateNode* op) final { auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); @@ -73,25 +58,8 @@ class RemapStorageScope final : public StmtExprMutator { } return StmtExprMutator::VisitStmt_(op); } - - Stmt VisitStmt_(const StoreNode* op) final { - auto remapped = StmtExprMutator::VisitExpr(op->buffer_var); - return Store(Downcast(remapped), StmtExprMutator::VisitExpr(op->value), - StmtExprMutator::VisitExpr(op->index), StmtExprMutator::VisitExpr(op->predicate)); - } - - PrimExpr VisitExpr_(const LoadNode* op) final { - auto remapped = StmtExprMutator::VisitExpr(op->buffer_var); - return Load(op->dtype, Downcast(remapped), StmtExprMutator::VisitExpr(op->index), - StmtExprMutator::VisitExpr(op->predicate)); - } - - private: - std::unordered_map new_var_remap_; }; -} // namespace - class ThreadAllreduceBuilder final : public StmtExprMutator { public: explicit ThreadAllreduceBuilder(const TargetNode* target) @@ -141,10 +109,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { const AllocateNode* repl = it->second.as(); if (warp_allocs_.count(repl)) { stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body); - new_var_remap_[repl->buffer_var.get()] = WithStorageScope(repl->buffer_var, "local"); + new_storage_scopes_[repl->buffer_var.get()] = "local"; } else { stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body); - new_var_remap_[repl->buffer_var.get()] = WithStorageScope(repl->buffer_var, "shared"); + new_storage_scopes_[repl->buffer_var.get()] = "shared"; } return stmt; } else { @@ -161,7 +129,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } } - std::unordered_map new_var_remap_; + std::unordered_map new_storage_scopes_; 
private: // Thread entry @@ -421,7 +389,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { const AllocateNode* repl = var.as(); if (repl) { body = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body); - new_var_remap_[repl->buffer_var.get()] = WithStorageScope(repl->buffer_var, "local"); + new_storage_scopes_[repl->buffer_var.get()] = "local"; } } @@ -647,7 +615,7 @@ Pass LowerThreadAllreduce() { const TargetNode* target_node = target.as(); ThreadAllreduceBuilder thread_all_reduce(target_node); auto reduce_body = thread_all_reduce(n->body); - n->body = RemapStorageScope(thread_all_reduce.new_var_remap_)(reduce_body); + n->body = RemapStorageScopeAllReduce(thread_all_reduce.new_storage_scopes_)(reduce_body); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {}); diff --git a/src/tir/transforms/remap_pointer_storage_scope.cc b/src/tir/transforms/remap_pointer_storage_scope.cc index 61b22f078931..8225cddeb094 100644 --- a/src/tir/transforms/remap_pointer_storage_scope.cc +++ b/src/tir/transforms/remap_pointer_storage_scope.cc @@ -28,7 +28,7 @@ #include #include -#include +#include #include "../../runtime/thread_storage_scope.h" #include "ir_utils.h" diff --git a/src/tir/transforms/remap_pointer_storage_scope.h b/src/tir/transforms/remap_pointer_storage_scope.h index 9689effd11fe..0756810d8c9e 100644 --- a/src/tir/transforms/remap_pointer_storage_scope.h +++ b/src/tir/transforms/remap_pointer_storage_scope.h @@ -21,16 +21,19 @@ * TODO * \file remap_pointer_storage_scope.h */ +#ifndef TVM_TIR_TRANSFORMS_REMAP_POINTER_STORAGE_SCOPE_H_ +#define TVM_TIR_TRANSFORMS_REMAP_POINTER_STORAGE_SCOPE_H_ + #include #include #include -#include +#include namespace tvm { namespace tir { -class RemapStorageScope final : public StmtExprMutator { +class RemapStorageScope : public StmtExprMutator { public: explicit RemapStorageScope(const std::unordered_map& new_storage_scopes); @@ -46,3 +49,4 @@ class RemapStorageScope final : public StmtExprMutator { } // namespace tir } // namespace tvm +#endif // TVM_TIR_TRANSFORMS_REMAP_POINTER_STORAGE_SCOPE_H_ From 0bf50813fc61de630b78496e18b28f4cffa7cf7d Mon Sep 17 00:00:00 2001 From: masa Date: Wed, 7 Jul 2021 09:26:48 +0900 Subject: [PATCH 82/90] RemapStorageScope -> UpdatePointerStorageScope --- python/tvm/tir/ir_builder.py | 2 +- src/tir/transforms/lower_thread_allreduce.cc | 11 ++++++----- src/tir/transforms/lower_warp_memory.cc | 4 ++-- ...cope.cc => update_pointer_storage_scope.cc} | 18 +++++++++--------- ..._scope.h => update_pointer_storage_scope.h} | 15 ++++++++------- 5 files changed, 26 insertions(+), 24 deletions(-) rename src/tir/transforms/{remap_pointer_storage_scope.cc => update_pointer_storage_scope.cc} (84%) rename src/tir/transforms/{remap_pointer_storage_scope.h => update_pointer_storage_scope.h} (74%) diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 1573d96e7d0d..03c1339c4d38 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -436,7 +436,7 @@ def pointer(self, content_type, name="ptr", scope=""): The name of the pointer. scope : str, optional - The scope of the buffer. + The scope of the pointer. 
Returns ------- diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index d4c1c6aea594..25a2f4e060dd 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -33,16 +33,16 @@ #include "../../runtime/thread_storage_scope.h" #include "ir_utils.h" -#include "remap_pointer_storage_scope.h" +#include "update_pointer_storage_scope.h" namespace tvm { namespace tir { -class RemapStorageScopeAllReduce final : public RemapStorageScope { +class UpdatePointerStorageScopeAllReduce final : public UpdatePointerStorageScope { public: - explicit RemapStorageScopeAllReduce( + explicit UpdatePointerStorageScopeAllReduce( const std::unordered_map& new_storage_scopes) - : RemapStorageScope(new_storage_scopes) {} + : UpdatePointerStorageScope(new_storage_scopes) {} Stmt VisitStmt_(const AllocateNode* op) final { auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); @@ -615,7 +615,8 @@ Pass LowerThreadAllreduce() { const TargetNode* target_node = target.as(); ThreadAllreduceBuilder thread_all_reduce(target_node); auto reduce_body = thread_all_reduce(n->body); - n->body = RemapStorageScopeAllReduce(thread_all_reduce.new_storage_scopes_)(reduce_body); + n->body = + UpdatePointerStorageScopeAllReduce(thread_all_reduce.new_storage_scopes_)(reduce_body); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {}); diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index bc43d4da5311..060b02c3d137 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -41,7 +41,7 @@ #include "../../arith/pattern_match.h" #include "../../runtime/thread_storage_scope.h" #include "ir_utils.h" -#include "remap_pointer_storage_scope.h" +#include "update_pointer_storage_scope.h" namespace tvm { namespace tir { @@ -401,7 +401,7 @@ Pass LowerWarpMemory() { int warp_size = target.value()->GetAttr("thread_warp_size", 1).value(); WarpMemoryRewriter warp_memory_rewriter(warp_size); auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body)); - n->body = RemapStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt); + n->body = UpdatePointerStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {}); diff --git a/src/tir/transforms/remap_pointer_storage_scope.cc b/src/tir/transforms/update_pointer_storage_scope.cc similarity index 84% rename from src/tir/transforms/remap_pointer_storage_scope.cc rename to src/tir/transforms/update_pointer_storage_scope.cc index 8225cddeb094..ae72e7f947cd 100644 --- a/src/tir/transforms/remap_pointer_storage_scope.cc +++ b/src/tir/transforms/update_pointer_storage_scope.cc @@ -18,10 +18,10 @@ */ /*! - * TODO - * \file remap_pointer_storage_scope.cc + * \file update_pointer_storage_scope.cc + * \brief A pass to update storage scopes for buffer variables. 
*/ -#include "remap_pointer_storage_scope.h" +#include "update_pointer_storage_scope.h" #include #include @@ -43,14 +43,14 @@ Var WithStorageScope(const VarNode* buffer_var, String storage_scope) { buffer_var->span); } -RemapStorageScope::RemapStorageScope( +UpdatePointerStorageScope::UpdatePointerStorageScope( const std::unordered_map& new_storage_scopes) { for (auto kv : new_storage_scopes) { new_var_remap_[kv.first] = WithStorageScope(kv.first, kv.second); } } -PrimExpr RemapStorageScope::VisitExpr_(const VarNode* op) { +PrimExpr UpdatePointerStorageScope::VisitExpr_(const VarNode* op) { auto it = new_var_remap_.find(op); if (it == new_var_remap_.end()) { return GetRef(op); @@ -58,13 +58,13 @@ PrimExpr RemapStorageScope::VisitExpr_(const VarNode* op) { return it->second; } -PrimExpr RemapStorageScope::VisitExpr_(const LoadNode* op) { +PrimExpr UpdatePointerStorageScope::VisitExpr_(const LoadNode* op) { auto remapped = StmtExprMutator::VisitExpr(op->buffer_var); return Load(op->dtype, Downcast(remapped), StmtExprMutator::VisitExpr(op->index), StmtExprMutator::VisitExpr(op->predicate)); } -Stmt RemapStorageScope::VisitStmt_(const AttrStmtNode* op) { +Stmt UpdatePointerStorageScope::VisitStmt_(const AttrStmtNode* op) { using runtime::StorageScope; if (op->attr_key == attr::storage_scope) { const VarNode* buf = op->node.as(); @@ -76,13 +76,13 @@ Stmt RemapStorageScope::VisitStmt_(const AttrStmtNode* op) { return StmtMutator::VisitStmt_(op); } -Stmt RemapStorageScope::VisitStmt_(const AllocateNode* op) { +Stmt UpdatePointerStorageScope::VisitStmt_(const AllocateNode* op) { auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); return Allocate(remapped, op->dtype, op->extents, StmtExprMutator::VisitExpr(op->condition), StmtExprMutator::VisitStmt(op->body)); } -Stmt RemapStorageScope::VisitStmt_(const StoreNode* op) { +Stmt UpdatePointerStorageScope::VisitStmt_(const StoreNode* op) { auto remapped = StmtExprMutator::VisitExpr(op->buffer_var); return Store(Downcast(remapped), StmtExprMutator::VisitExpr(op->value), StmtExprMutator::VisitExpr(op->index), StmtExprMutator::VisitExpr(op->predicate)); diff --git a/src/tir/transforms/remap_pointer_storage_scope.h b/src/tir/transforms/update_pointer_storage_scope.h similarity index 74% rename from src/tir/transforms/remap_pointer_storage_scope.h rename to src/tir/transforms/update_pointer_storage_scope.h index 0756810d8c9e..481536a45b27 100644 --- a/src/tir/transforms/remap_pointer_storage_scope.h +++ b/src/tir/transforms/update_pointer_storage_scope.h @@ -18,11 +18,11 @@ */ /*! - * TODO - * \file remap_pointer_storage_scope.h + * \file update_pointer_storage_scope.h + * \brief A pass to update storage scopes for buffer variables. 
*/ -#ifndef TVM_TIR_TRANSFORMS_REMAP_POINTER_STORAGE_SCOPE_H_ -#define TVM_TIR_TRANSFORMS_REMAP_POINTER_STORAGE_SCOPE_H_ +#ifndef TVM_TIR_TRANSFORMS_UPDATE_POINTER_STORAGE_SCOPE_H_ +#define TVM_TIR_TRANSFORMS_UPDATE_POINTER_STORAGE_SCOPE_H_ #include #include @@ -33,9 +33,10 @@ namespace tvm { namespace tir { -class RemapStorageScope : public StmtExprMutator { +class UpdatePointerStorageScope : public StmtExprMutator { public: - explicit RemapStorageScope(const std::unordered_map& new_storage_scopes); + explicit UpdatePointerStorageScope( + const std::unordered_map& new_storage_scopes); virtual PrimExpr VisitExpr_(const VarNode*); virtual PrimExpr VisitExpr_(const LoadNode*); @@ -49,4 +50,4 @@ class RemapStorageScope : public StmtExprMutator { } // namespace tir } // namespace tvm -#endif // TVM_TIR_TRANSFORMS_REMAP_POINTER_STORAGE_SCOPE_H_ +#endif // TVM_TIR_TRANSFORMS_UPDATE_POINTER_STORAGE_SCOPE_H_ From c03c360cc4a96d661b4d635ada1790fc32f23746 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 7 Jul 2021 14:57:50 +0900 Subject: [PATCH 83/90] remove realize_scope from hybrid script --- python/tvm/te/hybrid/parser.py | 3 +-- src/contrib/hybrid/codegen_hybrid.cc | 9 ++------- src/contrib/hybrid/codegen_hybrid.h | 2 -- tests/python/unittest/test_te_hybrid_script.py | 4 +--- 4 files changed, 4 insertions(+), 14 deletions(-) diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py index 7bb85e3da83c..442aeb6f1027 100644 --- a/python/tvm/te/hybrid/parser.py +++ b/python/tvm/te/hybrid/parser.py @@ -207,8 +207,7 @@ def wrap_up_realize(self, node, body): _domain = [Range.from_min_extent(0, i) for i in _buf.shape] _dtype = _buf.dtype _true = tvm.runtime.convert(True) - body = tvm.tir.ProducerRealize(_buf, _domain, _true, body) - body = tvm.tir.AttrStmt(_buf.op, "realize_scope", tvm.runtime.convert(_scope), body) + body = tvm.tir.ProducerRealize(_buf, _domain, _true, body, tvm.runtime.convert(_scope)) for elem in to_pop: self.symbols.pop(elem) diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index 7522f20523c8..54edbaee35cd 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -315,10 +315,6 @@ void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) { indent_ += tab_; PrintStmt(op->body); indent_ -= tab_; - } else if (op->attr_key == tir::attr::realize_scope) { - auto v = Downcast(op->node); - alloc_storage_scope_[v] = op->value.as()->value; - PrintStmt(op->body); } else { // For now we ignore the unsupported AttrStmt PrintStmt(op->body); @@ -327,8 +323,7 @@ void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) { void CodeGenHybrid::VisitStmt_(const ProducerRealizeNode* op) { auto tensor = Downcast(op->producer); - ICHECK(alloc_storage_scope_.count(tensor->op)); - if (!alloc_storage_scope_[tensor->op].empty()) { + if (!op->storage_scope.empty()) { PrintIndent(); stream << GetTensorID(tensor) << " = allocate(("; for (size_t i = 0; i < op->bounds.size(); ++i) { @@ -339,7 +334,7 @@ void CodeGenHybrid::VisitStmt_(const ProducerRealizeNode* op) { stream << "), '"; PrintType(tensor->dtype, stream); stream << "', '"; - stream << alloc_storage_scope_[tensor->op] << "')\n"; + stream << op->storage_scope << "')\n"; } PrintStmt(op->body); } diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index b01ca2763e28..47c13f73022f 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -168,8 +168,6 @@ class CodeGenHybrid : public 
ExprFunctor, * \param tensor The tensor to allocate a name. */ std::string GetTensorID(const Tensor& tensor); - /*! \brief the storage scope of allocation */ - std::map alloc_storage_scope_; }; } // namespace contrib diff --git a/tests/python/unittest/test_te_hybrid_script.py b/tests/python/unittest/test_te_hybrid_script.py index 30b96546f991..e9626e7f31b4 100644 --- a/tests/python/unittest/test_te_hybrid_script.py +++ b/tests/python/unittest/test_te_hybrid_script.py @@ -189,9 +189,7 @@ def fanout(n, a): assert ir.min.value == 0 assert tvm.ir.structural_equal(ir.extent, n - 3) # Check loopbody - ibody = ir.body - assert isinstance(ibody, tvm.tir.AttrStmt) - abody = ibody.body + abody = ir.body assert isinstance(abody, tvm.tir.ProducerRealize) assert abody.bounds[0].min.value == 0 assert abody.bounds[0].extent.value == 1 From 21d41343e208dfafe407c96dd1cbbd9bccc122a6 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 7 Jul 2021 15:23:53 +0900 Subject: [PATCH 84/90] removed realize_scope in schedule_ops --- docs/dev/inferbound.rst | 2 -- src/te/schedule/schedule_ops.cc | 5 +---- tests/python/unittest/test_te_schedule_tensorize.py | 4 ++-- tests/python/unittest/test_te_tensor.py | 2 +- tests/python/unittest/test_tir_transform_loop_partition.py | 4 ++-- 5 files changed, 6 insertions(+), 11 deletions(-) diff --git a/docs/dev/inferbound.rst b/docs/dev/inferbound.rst index 010d0d42d37e..28e034dc44cb 100644 --- a/docs/dev/inferbound.rst +++ b/docs/dev/inferbound.rst @@ -447,13 +447,11 @@ Here is the IR after ScheduleOps (note that loops with extent 1 have been preser :: - // attr [compute(D, 0x2c070b0)] realize_scope = "" realize D([0, 4], [0, 5], [0, 16]) { produce D { for (di, 0, 4) { for (dj, 0, 5) { for (dk, 0, 16) { - // attr [compute(C, 0x2c29990)] realize_scope = "" realize C([dj, 1], [dk, 1]) { produce C { for (i, 0, 1) { diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index 21edd2f94b20..9faff741372b 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -51,11 +51,8 @@ Stmt MakePipeline(const Stage& s, const std::unordered_map& dom_ if (consumer.defined() && !is_no_op(consumer)) { pipeline = SeqStmt({producer, consumer}); } - pipeline = s->op->BuildRealize(s, dom_map, pipeline, s->scope); - // use attribute to mark scope of the operation. - pipeline = AttrStmt(s->op, tir::attr::realize_scope, StringImm(s->scope), pipeline); - return pipeline; + return s->op->BuildRealize(s, dom_map, pipeline, s->scope); } // inject the operator's realization on the stmt. 
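Reviewer note: after this patch the realization scope rides on ProducerRealize itself, mirroring the five-argument tvm.tir.ProducerRealize call in the hybrid parser above. A hedged C++ sketch of constructing such a node (`tensor`, `bounds` and `body` are placeholders a lowering pass would supply):

    // The scope is now a field of the realize node; no wrapping
    // AttrStmt(op, "realize_scope", ...) is required any more.
    tir::Stmt realize = tir::ProducerRealize(tensor,             // DataProducer, e.g. a te::Tensor
                                             bounds,             // Region of per-axis Ranges
                                             tir::const_true(),  // realization condition
                                             body,               // the wrapped statement
                                             "shared");          // storage scope string
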
diff --git a/tests/python/unittest/test_te_schedule_tensorize.py b/tests/python/unittest/test_te_schedule_tensorize.py index e2c2f7f7e0e5..ae5e7051bfba 100644 --- a/tests/python/unittest/test_te_schedule_tensorize.py +++ b/tests/python/unittest/test_te_schedule_tensorize.py @@ -379,8 +379,8 @@ def intrin_func(ins, outs): stmt = tvm.te.schedule.ScheduleOps(s, dom_map) # The loop that we tried to tensorize still exists in the code # That means tensorize didn't work as expected - assert isinstance(stmt.body.body, tvm.tir.For) - assert stmt.body.body.loop_var.name == C.op.axis[0].var.name + assert isinstance(stmt.body, tvm.tir.For) + assert stmt.body.loop_var.name == C.op.axis[0].var.name if __name__ == "__main__": diff --git a/tests/python/unittest/test_te_tensor.py b/tests/python/unittest/test_te_tensor.py index ed4a21397885..2931925965b7 100644 --- a/tests/python/unittest/test_te_tensor.py +++ b/tests/python/unittest/test_te_tensor.py @@ -309,7 +309,7 @@ def get_B1_realize(x): ret = [] tvm.tir.stmt_functor.post_order_visit(stmt, get_B1_realize) - assert stmt.node == C.op and len(ret) == 1 + assert stmt.producer == C and len(ret) == 1 def test_tensor_inputs(): diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py index 9e8848083908..6194024748e0 100644 --- a/tests/python/unittest/test_tir_transform_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -40,7 +40,7 @@ def test_basic(): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt)) mod = tvm.tir.transform.LoopPartition()(mod) - stmt = tvm.tir.transform.Simplify()(mod)["main"].body + stmt = tvm.tir.transform.Simplify()(mod)["main"] assert not any(collect_visit(stmt.body.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))) assert any(collect_visit(stmt.body.body[1], lambda x: isinstance(x, tvm.tir.IfThenElse))) @@ -156,7 +156,7 @@ def test_thread_axis(): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) mod = tvm.tir.transform.LoopPartition()(mod) - stmt = tvm.tir.transform.Simplify()(mod)["main"].body + stmt = tvm.tir.transform.Simplify()(mod)["main"] assert not any(collect_visit(stmt.body.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))) From 0ff503b401179b83a95ef8b68dfc547946839c14 Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Wed, 7 Jul 2021 15:43:51 +0900 Subject: [PATCH 85/90] remove realize_scope from schedule_postproc_to_primfunc --- src/te/schedule/schedule_postproc_to_primfunc.cc | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc index 8e6cc131b76e..2063fc7cad6a 100644 --- a/src/te/schedule/schedule_postproc_to_primfunc.cc +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -67,10 +67,7 @@ class TensorToBufferMapper : public StmtExprMutator { Stmt VisitStmt_(const AttrStmtNode* op) final { auto ret = StmtExprMutator::VisitStmt_(op); op = ret.as(); - // TODO(tvm-team): remove realize_scope, turn the info into - // Buffer's scope field in this pass. 
-    if (op->attr_key == tir::attr::realize_scope ||
-        op->attr_key == tir::attr::double_buffer_scope) {
+    if (op->attr_key == tir::attr::double_buffer_scope) {
       Stmt body = op->body;
       Operation operation = Downcast<Operation>(op->node);
       for (int i = operation->num_outputs(); i != 0; --i) {

From 086d891f99f5c16fc2187f7bd60c7293da49305b Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Wed, 7 Jul 2021 15:52:57 +0900
Subject: [PATCH 86/90] remove remaining realize_scope usage from
 schedule_ops.cc

---
 src/te/schedule/schedule_ops.cc | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc
index 9faff741372b..825092d20ac0 100644
--- a/src/te/schedule/schedule_ops.cc
+++ b/src/te/schedule/schedule_ops.cc
@@ -172,8 +172,7 @@ class SchedulePostProc : public StmtExprMutator {
         thread_extent_scope_.erase(op->node.get());
         return ret;
       }
-    } else if (op->attr_key == tir::attr::realize_scope ||
-               op->attr_key == tir::attr::double_buffer_scope) {
+    } else if (op->attr_key == tir::attr::double_buffer_scope) {
       auto it = replace_op_.find(op->node.get());
       if (it != replace_op_.end()) {
         if (it->second.defined()) {

From e60ad8d321b9562f8e52b1f162dc3769733ed54e Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Wed, 7 Jul 2021 16:06:21 +0900
Subject: [PATCH 87/90] remove realize_scope usage from storage_flatten.cc

---
 src/tir/transforms/storage_flatten.cc | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc
index 3eccf300639a..0db86130a8da 100644
--- a/src/tir/transforms/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -78,10 +78,7 @@ class StorageFlattener : public StmtExprMutator {
   }

   Stmt VisitStmt_(const AttrStmtNode* op) final {
-    if (op->attr_key == attr::realize_scope) {
-      return this->VisitStmt(op->body);
-    } else if (op->attr_key == attr::double_buffer_scope &&
-               op->node->IsInstance<tir::BufferNode>()) {
+    if (op->attr_key == attr::double_buffer_scope && op->node->IsInstance<tir::BufferNode>()) {
       auto buffer = Downcast<tir::Buffer>(op->node);
       Stmt body = this->VisitStmt(op->body);
       auto it = buf_map_.find(buffer);

From cb697d846cbfedf41bcadb08c0eed3f734fe2efc Mon Sep 17 00:00:00 2001
From: masa
Date: Wed, 7 Jul 2021 17:56:53 +0900
Subject: [PATCH 88/90] fixed test_tir_transform_lower_warp_memory.py
 following realize_scope removal

---
 tests/python/unittest/test_tir_transform_lower_warp_memory.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
index ef474c15cfbb..f3baff120cf6 100644
--- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py
+++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py
@@ -72,8 +72,8 @@ def test_lower_warp_memory_correct_indices():
     bounds = tvm.te.schedule.InferBound(s)
     ir = tvm.te.schedule.ScheduleOps(s, bounds)
-    inner_func = ir.body.body.body.body
-    store_A_warp = inner_func.body.seq[0].body.body
+    inner_func = ir.body.body.body
+    store_A_warp = inner_func.seq[0].body.body
     indices = list(store_A_warp.indices)

     # A.warp is actually many buffers, one for each warp, although they are all called A.warp

From f459dbfd3a09e5bcba25fd6838cf878efb961fa7 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Sat, 10 Jul 2021 16:19:13 +0900
Subject: [PATCH 89/90] Address comments

---
 src/runtime/thread_storage_scope.h                 | 2 +-
 src/tir/transforms/update_pointer_storage_scope.cc | 3 +--
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h
index d93a1f130bae..9d140aedd810 100644
--- a/src/runtime/thread_storage_scope.h
+++ b/src/runtime/thread_storage_scope.h
@@ -118,7 +118,7 @@ struct StorageScope {
    */
   static StorageScope Create(const std::string& s) {
     StorageScope r;
-    if (s == "") {
+    if (s.empty()) {
       r.rank = StorageRank::kGlobal;
     } else if (s.compare(0, 6, "global") == 0) {
       r.rank = StorageRank::kGlobal;

diff --git a/src/tir/transforms/update_pointer_storage_scope.cc b/src/tir/transforms/update_pointer_storage_scope.cc
index ae72e7f947cd..0ae02fec9f95 100644
--- a/src/tir/transforms/update_pointer_storage_scope.cc
+++ b/src/tir/transforms/update_pointer_storage_scope.cc
@@ -45,7 +45,7 @@ Var WithStorageScope(const VarNode* buffer_var, String storage_scope) {

 UpdatePointerStorageScope::UpdatePointerStorageScope(
     const std::unordered_map<const VarNode*, String>& new_storage_scopes) {
-  for (auto kv : new_storage_scopes) {
+  for (auto& kv : new_storage_scopes) {
     new_var_remap_[kv.first] = WithStorageScope(kv.first, kv.second);
   }
 }
@@ -65,7 +65,6 @@ PrimExpr UpdatePointerStorageScope::VisitExpr_(const LoadNode* op) {
 }

 Stmt UpdatePointerStorageScope::VisitStmt_(const AttrStmtNode* op) {
-  using runtime::StorageScope;
   if (op->attr_key == attr::storage_scope) {
     const VarNode* buf = op->node.as<VarNode>();
     auto remapped = Downcast<Var>(StmtExprMutator::VisitExpr(GetRef<Var>(buf)));

From 3e5e0efe2dfde96f9b3a162ff66524ff79b62723 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Mon, 12 Jul 2021 11:24:07 +0900
Subject: [PATCH 90/90] Remove blank line

---
 src/te/operation/cross_thread_reduction.cc | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc
index 0c20328f02b7..f844090ca6f5 100644
--- a/src/te/operation/cross_thread_reduction.cc
+++ b/src/te/operation/cross_thread_reduction.cc
@@ -1,4 +1,3 @@
-
 /*
  * Licensed to the Apache Software Foundation (ASF) under one
  * or more contributor license agreements.  See the NOTICE file