diff --git a/docs/dev/inferbound.rst b/docs/dev/inferbound.rst index 010d0d42d37e..28e034dc44cb 100644 --- a/docs/dev/inferbound.rst +++ b/docs/dev/inferbound.rst @@ -447,13 +447,11 @@ Here is the IR after ScheduleOps (note that loops with extent 1 have been preser :: - // attr [compute(D, 0x2c070b0)] realize_scope = "" realize D([0, 4], [0, 5], [0, 16]) { produce D { for (di, 0, 4) { for (dj, 0, 5) { for (dk, 0, 16) { - // attr [compute(C, 0x2c29990)] realize_scope = "" realize C([dj, 1], [dk, 1]) { produce C { for (i, 0, 1) { diff --git a/include/tvm/te/operation.h b/include/tvm/te/operation.h index 27e48999a7d1..13f39317dbe4 100644 --- a/include/tvm/te/operation.h +++ b/include/tvm/te/operation.h @@ -125,11 +125,12 @@ class TVM_DLL OperationNode : public Object { * \param stage the op's stage. * \param realize_map The realization domain map of the operators. * \param body The body that is going to get + * \param storage_scope The storage scope associated with this realization * \return A realization statement that wraps body. */ virtual Stmt BuildRealize(const Stage& stage, - const std::unordered_map& realize_map, - const Stmt& body) const = 0; + const std::unordered_map& realize_map, const Stmt& body, + String storage_scope = "") const = 0; /*! * \brief Build the statement that provide the output tensors. * \param stage The schedule stage of the op. @@ -168,7 +169,7 @@ class PlaceholderOpNode : public OperationNode { void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, std::unordered_map* out_dom_map) const final; Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const final; + const Stmt& body, String storage_scope = "") const final; Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) const final; @@ -212,7 +213,7 @@ class TVM_DLL BaseComputeOpNode : public OperationNode { void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, std::unordered_map* out_dom_map) const final; Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const final; + const Stmt& body, String storage_scope = "") const final; virtual size_t num_schedulable_dims() const = 0; static constexpr const char* _type_key = "BaseComputeOp"; @@ -370,7 +371,7 @@ class ScanOpNode : public OperationNode { void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, std::unordered_map* out_dom_map) const final; Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const final; + const Stmt& body, String storage_scope = "") const final; Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) const final; @@ -433,7 +434,7 @@ class ExternOpNode : public OperationNode { void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, std::unordered_map* out_dom_map) const final; Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const final; + const Stmt& body, String storage_scope = "") const final; Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) const final; @@ -498,7 +499,7 @@ class HybridOpNode : public OperationNode { void GatherBound(const Operation& self, const std::unordered_map& tensor_dom, std::unordered_map* out_dom_map) const final; Stmt BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const final; + const Stmt& body, String storage_scope = "") const final; Stmt BuildProvide(const Stage& stage, const std::unordered_map& dom_map, bool debug_keep_trivial_loop) const final; diff --git a/include/tvm/tir/buffer.h b/include/tvm/tir/buffer.h index 017f4f7052b1..2507262c087f 100644 --- a/include/tvm/tir/buffer.h +++ b/include/tvm/tir/buffer.h @@ -191,12 +191,13 @@ class Buffer : public ObjectRef { * \param shape The shape of the buffer, * \param dtype The content data type. * \param name The name of the buffer + * \param storage_scope The storage scope associated with this buffer * \param span The location of this object in the source code. * \return The created buffer. * \sa Buffer for complete constructor. */ TVM_DLL Buffer decl_buffer(Array shape, DataType dtype = DataType::Float(32), - String name = "buffer", Span span = Span()); + String name = "buffer", String storage_scope = "", Span span = Span()); /*! * \brief Base node for data producers. diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index cc10c218c8ff..9997a4d95694 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -464,18 +464,22 @@ class ProducerRealizeNode : public StmtNode { PrimExpr condition; /*! \brief The body of realization. */ Stmt body; + /*! \brief The storage scope associated with this realization. */ + String storage_scope; void VisitAttrs(AttrVisitor* v) { v->Visit("producer", &producer); v->Visit("bounds", &bounds); v->Visit("condition", &condition); v->Visit("body", &body); + v->Visit("storage_scope", &storage_scope); v->Visit("span", &span); } bool SEqualReduce(const ProducerRealizeNode* other, SEqualReducer equal) const { return equal(producer, other->producer) && equal(bounds, other->bounds) && - equal(condition, other->condition) && equal(body, other->body); + equal(condition, other->condition) && equal(body, other->body) && + equal(storage_scope, other->storage_scope); } void SHashReduce(SHashReducer hash_reduce) const { @@ -483,6 +487,7 @@ class ProducerRealizeNode : public StmtNode { hash_reduce(bounds); hash_reduce(condition); hash_reduce(body); + hash_reduce(storage_scope); } static constexpr const char* _type_key = "tir.ProducerRealize"; @@ -496,7 +501,7 @@ class ProducerRealizeNode : public StmtNode { class ProducerRealize : public Stmt { public: TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body, - Span span = Span()); + String storage_scope = "", Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode); }; diff --git a/python/tvm/script/scope_handler.py b/python/tvm/script/scope_handler.py index a23401d926e9..d07209485bd4 100644 --- a/python/tvm/script/scope_handler.py +++ b/python/tvm/script/scope_handler.py @@ -140,7 +140,7 @@ def enter_scope( def setup_buffer_var(extents, dtype, scope, condition=True, span: Span = None): """Setup buffer var for a given type.""" - buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype)) + buffer_ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), scope) self.buffer_var = tvm.tir.Var(name, buffer_ptr_type, span) setup_buffer_var(*arg_list, span=tvm_span_from_synr(var_span)) diff --git a/python/tvm/script/special_stmt.py b/python/tvm/script/special_stmt.py index 7eb938c58f96..e40bc2fda6eb 100644 --- a/python/tvm/script/special_stmt.py +++ b/python/tvm/script/special_stmt.py @@ -491,6 +491,22 @@ def var(dtype, span): super().__init__(var, def_symbol=True) +@register +class BufferVarDef(SpecialStmt): + """Special function for defining a variable of pointer type""" + + def __init__(self): + def buffer_var(dtype, storage_scope, span): + assert isinstance( + self.node, ast.Assign + ), f"BufferVarDef expected ast.Assign but got {type(self.node)}" + ptr_type = tvm.ir.PointerType(tvm.ir.PrimType(dtype), storage_scope) + v = te.var(self.node.lhs.id.name, ptr_type, span=span) + self.context.update_symbol(v.name, v, self.node) + + super().__init__(buffer_var, def_symbol=True) + + @register class EnvThread(SpecialStmt): """Bind a var to thread env""" diff --git a/python/tvm/te/hybrid/parser.py b/python/tvm/te/hybrid/parser.py index 7bb85e3da83c..442aeb6f1027 100644 --- a/python/tvm/te/hybrid/parser.py +++ b/python/tvm/te/hybrid/parser.py @@ -207,8 +207,7 @@ def wrap_up_realize(self, node, body): _domain = [Range.from_min_extent(0, i) for i in _buf.shape] _dtype = _buf.dtype _true = tvm.runtime.convert(True) - body = tvm.tir.ProducerRealize(_buf, _domain, _true, body) - body = tvm.tir.AttrStmt(_buf.op, "realize_scope", tvm.runtime.convert(_scope), body) + body = tvm.tir.ProducerRealize(_buf, _domain, _true, body, tvm.runtime.convert(_scope)) for elem in to_pop: self.symbols.pop(elem) diff --git a/python/tvm/tir/buffer.py b/python/tvm/tir/buffer.py index d9dc9e58acd6..086d93f49a2b 100644 --- a/python/tvm/tir/buffer.py +++ b/python/tvm/tir/buffer.py @@ -252,7 +252,7 @@ def decl_buffer( # Bool is represented as uint1 in the IR, but stored as int8 storage_type = PrimType(dtype) storage_type = PrimType("int8") if storage_type.dtype == "bool" else storage_type - data = Var(name, PointerType(storage_type), span) + data = Var(name, PointerType(storage_type, scope), span) return _ffi_api.Buffer( # type: ignore data, dtype, diff --git a/python/tvm/tir/ir_builder.py b/python/tvm/tir/ir_builder.py index 4796cb0d9549..35932540fe68 100644 --- a/python/tvm/tir/ir_builder.py +++ b/python/tvm/tir/ir_builder.py @@ -394,7 +394,7 @@ def let(self, var_name, value): self.emit(lambda x: _stmt.LetStmt(var, value, x)) return var - def allocate(self, dtype, shape, name="buf", scope=None): + def allocate(self, dtype, shape, name="buf", scope=""): """Create a allocate statement. Parameters @@ -416,7 +416,7 @@ def allocate(self, dtype, shape, name="buf", scope=None): buffer : BufferVar The buffer var representing the buffer. """ - buffer_var = _expr.Var(name, PointerType(PrimType(dtype))) + buffer_var = _expr.Var(name, PointerType(PrimType(dtype), scope)) if not isinstance(shape, (list, tuple, _container.Array)): shape = [shape] if scope: @@ -424,7 +424,7 @@ def allocate(self, dtype, shape, name="buf", scope=None): self.emit(lambda x: _stmt.Allocate(buffer_var, dtype, shape, const(1, dtype="uint1"), x)) return BufferVar(self, buffer_var, shape, dtype) - def pointer(self, content_type, name="ptr"): + def pointer(self, content_type, name="ptr", scope=""): """Create pointer variable with content type. Parameters @@ -435,12 +435,15 @@ def pointer(self, content_type, name="ptr"): name : str, optional The name of the pointer. + scope : str, optional + The scope of the pointer. + Returns ------- ptr : BufferVar The buffer var representing the buffer. """ - buffer_var = _expr.Var(name, dtype="handle") + buffer_var = _expr.Var(name, PointerType(PrimType(content_type), scope)) return BufferVar(self, buffer_var, None, content_type) def buffer_ptr(self, buf, shape=None): diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index da737cbd3c60..d57077f08b52 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -374,13 +374,22 @@ class ProducerRealize(Stmt): body : Stmt The realize body + storage_scope : str + The storage scope associated with this realization + span : Optional[Span] The location of this itervar in the source code. """ - def __init__(self, producer, bounds, condition, body, span=None): + def __init__(self, producer, bounds, condition, body, storage_scope="", span=None): self.__init_handle_by_constructor__( - _ffi_api.ProducerRealize, producer, bounds, condition, body, span # type: ignore + _ffi_api.ProducerRealize, + producer, + bounds, + condition, + body, + storage_scope, + span, # type: ignore ) diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index 7522f20523c8..54edbaee35cd 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -315,10 +315,6 @@ void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) { indent_ += tab_; PrintStmt(op->body); indent_ -= tab_; - } else if (op->attr_key == tir::attr::realize_scope) { - auto v = Downcast(op->node); - alloc_storage_scope_[v] = op->value.as()->value; - PrintStmt(op->body); } else { // For now we ignore the unsupported AttrStmt PrintStmt(op->body); @@ -327,8 +323,7 @@ void CodeGenHybrid::VisitStmt_(const AttrStmtNode* op) { void CodeGenHybrid::VisitStmt_(const ProducerRealizeNode* op) { auto tensor = Downcast(op->producer); - ICHECK(alloc_storage_scope_.count(tensor->op)); - if (!alloc_storage_scope_[tensor->op].empty()) { + if (!op->storage_scope.empty()) { PrintIndent(); stream << GetTensorID(tensor) << " = allocate(("; for (size_t i = 0; i < op->bounds.size(); ++i) { @@ -339,7 +334,7 @@ void CodeGenHybrid::VisitStmt_(const ProducerRealizeNode* op) { stream << "), '"; PrintType(tensor->dtype, stream); stream << "', '"; - stream << alloc_storage_scope_[tensor->op] << "')\n"; + stream << op->storage_scope << "')\n"; } PrintStmt(op->body); } diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index b01ca2763e28..47c13f73022f 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -168,8 +168,6 @@ class CodeGenHybrid : public ExprFunctor, * \param tensor The tensor to allocate a name. */ std::string GetTensorID(const Tensor& tensor); - /*! \brief the storage scope of allocation */ - std::map alloc_storage_scope_; }; } // namespace contrib diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index 4bbe17064c87..e855712617ca 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -26,6 +26,7 @@ #include #include #include +#include #include #include #include @@ -1013,8 +1014,17 @@ Doc TVMScriptPrinter::PrintPrimFunc(const PrimFunc& primFunc) { return memo_var_[GetRef(a)].str() < memo_var_[GetRef(b)].str(); }); for (const auto& var : vars) { - header_var << Doc::NewLine() << Print(GetRef(var)) << " = tir.var("; - header_var << PrintDType(var->dtype) << ")"; + auto type = GetRef(var)->type_annotation; + if (auto* ptr_type = type.as()) { + auto* prim_type = ptr_type->element_type.as(); + ICHECK(prim_type); + header_var << Doc::NewLine() << Print(GetRef(var)) << " = tir.buffer_var("; + header_var << PrintDType(prim_type->dtype) << ", " + << Doc::StrLiteral(ptr_type->storage_scope) << ")"; + } else { + header_var << Doc::NewLine() << Print(GetRef(var)) << " = tir.var("; + header_var << PrintDType(var->dtype) << ")"; + } } } doc << Doc::Indent(4, header_attr << header_var << header_buf << body); diff --git a/src/relay/backend/aot_executor_codegen.cc b/src/relay/backend/aot_executor_codegen.cc index 84c17b53c83e..4df38b9449ae 100644 --- a/src/relay/backend/aot_executor_codegen.cc +++ b/src/relay/backend/aot_executor_codegen.cc @@ -722,7 +722,8 @@ class AOTExecutorCodegen : public ExprVisitor { // Define the storage allocator ids for (auto kv : storage_device_map_) { for (auto sid : kv.second->storage_ids) { - te::Var buffer_var(MakeString("sid_", sid), PointerType(PrimType(DataType::Int(8)))); + te::Var buffer_var(MakeString("sid_", sid), + PointerType(PrimType(DataType::Int(8)), "global")); sids_table_[sid] = buffer_var; } } diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h index c0393600b60c..9d140aedd810 100644 --- a/src/runtime/thread_storage_scope.h +++ b/src/runtime/thread_storage_scope.h @@ -118,7 +118,9 @@ struct StorageScope { */ static StorageScope Create(const std::string& s) { StorageScope r; - if (s.compare(0, 6, "global") == 0) { + if (s.empty()) { + r.rank = StorageRank::kGlobal; + } else if (s.compare(0, 6, "global") == 0) { r.rank = StorageRank::kGlobal; r.tag = s.substr(6, std::string::npos); } else if (s.compare(0, 6, "shared") == 0) { diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 78f8a50e4e1b..9aec8f4e867b 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -84,7 +84,8 @@ class CodeGenAMDGPU : public CodeGenLLVM { if (info.alignment > 16) { info.alignment = 16; } - if (info.scope.rank == runtime::StorageRank::kLocal) { + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); + if (storage_scope.rank == runtime::StorageRank::kLocal) { // const int local_address_space = 5; // TODO(tqchen): for higher version of LLVM, local address space can be set. llvm::AllocaInst* alloca = WithFunctionEntry([&]() { @@ -99,7 +100,7 @@ class CodeGenAMDGPU : public CodeGenLLVM { } buf = alloca; } else { - ICHECK(info.scope.rank == runtime::StorageRank::kShared) + ICHECK(storage_scope.rank == runtime::StorageRank::kShared) << "Can only allocate shared or local memory inside kernel"; // Shared memory: address space == 3 const unsigned shared_address_space = 3; diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 48ccefafe3c4..bdae93b82aff 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -501,7 +501,8 @@ void CodeGenLLVM::GetAlignment(DataType t, const VarNode* buf_var, const PrimExp auto it = alloc_storage_info_.find(buf_var); if (it != alloc_storage_info_.end()) { const StorageInfo& info = it->second; - *p_native_bits = NativeVectorBits(info.scope); + *p_native_bits = + NativeVectorBits(runtime::StorageScope::Create(GetPtrStorageScope(GetRef(buf_var)))); max_align_bits = info.alignment * 8; } else { *p_native_bits = native_vector_bits_; @@ -1390,11 +1391,6 @@ void CodeGenLLVM::VisitStmt_(const AttrStmtNode* op) { analyzer_->Bind(iv->var, Range::FromMinExtent(0, op->value)); } } - } else if (op->attr_key == tir::attr::storage_scope) { - const VarNode* v = op->node.as(); - ICHECK(v); - alloc_storage_info_[v].scope = - runtime::StorageScope::Create(op->value.as()->value); } else if (op->attr_key == tir::attr::storage_alignment) { const VarNode* v = op->node.as(); ICHECK(v); diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index d5fcfab6d889..810e59be7214 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -163,8 +163,6 @@ class CodeGenLLVM : public ExprFunctor, protected: /*! \brief The storage information */ struct StorageInfo { - /*! \brief The storage scope */ - runtime::StorageScope scope; /*! \brief The alignment of allocation */ int alignment{0}; }; diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index 9e56529ec9ef..43ea0e6b7ae9 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -59,8 +59,8 @@ class CodeGenNVPTX : public CodeGenLLVM { if (info.alignment > 16) { info.alignment = 16; } - - if (info.scope.rank == runtime::StorageRank::kLocal) { + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); + if (storage_scope.rank == runtime::StorageRank::kLocal) { // const int local_address_space = 5; // TODO(tqchen): for higher version of LLVM, local address space can be set. llvm::AllocaInst* alloca = WithFunctionEntry([&]() { @@ -75,7 +75,7 @@ class CodeGenNVPTX : public CodeGenLLVM { } buf = alloca; } else { - ICHECK(info.scope.rank == runtime::StorageRank::kShared) + ICHECK(storage_scope.rank == runtime::StorageRank::kShared) << "Can only allocate shared or local memory inside kernel"; // Shared memory: address space == 3 const unsigned shared_address_space = 3; diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index ae451f39f89b..834c57ac10fd 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -39,6 +39,7 @@ #include #include +#include "../../tir/transforms/ir_utils.h" #include "codegen_source_base.h" namespace tvm { diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index 6e76c3538e71..d7dcbec7ebe3 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -705,12 +705,7 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { this->PrintIndent(); int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation for now"; - const VarNode* buffer = op->buffer_var.as(); - auto it = alloc_storage_scope_.find(buffer); - ICHECK(it != alloc_storage_scope_.end()) - << "Buffer " << op->buffer_var << " is missing an AttrStmt with a \"storage_scope\" key"; - - std::string scope = it->second; + std::string scope = GetPtrStorageScope(op->buffer_var); if (scope.find("wmma.") == 0) { if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { ICHECK(op->dtype == DataType::Float(16) || op->dtype == DataType::Int(8) || @@ -724,6 +719,7 @@ void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { op->dtype == DataType::Int(32)) << "Accumulator only support half, float and int type for now"; } + const VarNode* buffer = op->buffer_var.as(); constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); PrintWmmaScope(scope, op->dtype, buffer, stream); } else { diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 5d52bee44e98..c1fa921d4507 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -32,6 +32,7 @@ #include "../../runtime/pack_args.h" #include "../../runtime/vulkan/vulkan_common.h" #include "../../runtime/vulkan/vulkan_shader.h" +#include "../../tir/transforms/ir_utils.h" namespace tvm { namespace codegen { @@ -644,13 +645,14 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) { ICHECK(!op->dtype.is_handle()); int32_t constant_size = op->constant_allocation_size(); ICHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation in GPU"; + spirv::Value buf; - StorageInfo& info = storage_info_[op->buffer_var.get()]; + auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(op->buffer_var)); spirv::SType etype = builder_->GetSType(op->dtype); - if (info.scope.rank == runtime::StorageRank::kLocal) { + if (storage_scope.rank == runtime::StorageRank::kLocal) { buf = builder_->Allocate(etype, static_cast(constant_size), spv::StorageClassFunction); - } else if (info.scope.rank == runtime::StorageRank::kShared) { + } else if (storage_scope.rank == runtime::StorageRank::kShared) { // Shared memory buf = builder_->Allocate(etype, static_cast(constant_size), spv::StorageClassWorkgroup); @@ -660,8 +662,10 @@ void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) { builder_->SetName(buf, op->buffer_var->name_hint); + StorageInfo& info = storage_info_[op->buffer_var.get()]; ICHECK(!info.content_fixed); info.UpdateContentType(op->dtype); + ICHECK(!var_map_.count(op->buffer_var.get())); var_map_[op->buffer_var.get()] = buf; this->VisitStmt(op->body); @@ -677,10 +681,6 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmtNode* op) { var_map_[iv->var.get()] = GetThreadIndex(iv, op->value); } } - } else if (op->attr_key == tir::attr::storage_scope) { - const VarNode* v = op->node.as(); - ICHECK(v); - storage_info_[v].scope = runtime::StorageScope::Create(op->value.as()->value); } else if (op->attr_key == tir::attr::volatile_scope) { const VarNode* v = op->node.as(); ICHECK(v); diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index 3868322a74e0..a44dc5fd3d34 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -116,8 +116,6 @@ class CodeGenSPIRV : public ExprFunctor, protected: /*! \brief The storage information */ struct StorageInfo { - /*! \brief The storage scope */ - runtime::StorageScope scope; /*! \brief Whether it is volatile */ bool is_volatile{false}; /*! \brief Whether it is volatile */ diff --git a/src/te/operation/compute_op.cc b/src/te/operation/compute_op.cc index 9a4eadb35619..26c08955f5ad 100644 --- a/src/te/operation/compute_op.cc +++ b/src/te/operation/compute_op.cc @@ -260,7 +260,7 @@ void BaseComputeOpNode::GatherBound(const Operation& self, Stmt BaseComputeOpNode::BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const { + const Stmt& body, String storage_scope) const { ICHECK_EQ(stage->op.get(), this); Region bounds; for (IterVar iv : this->axis) { @@ -269,7 +269,7 @@ Stmt BaseComputeOpNode::BuildRealize(const Stage& stage, Stmt realize = body; for (int i = this->num_outputs(); i > 0; --i) { Tensor t = stage->op.output(i - 1); - realize = tir::ProducerRealize(t, bounds, const_true(), realize); + realize = tir::ProducerRealize(t, bounds, const_true(), realize, storage_scope); // alignment requirement, only useful for compute for (size_t i = 0; i < num_schedulable_dims(); ++i) { auto it = stage->iter_var_attrs.find(this->axis[i]); diff --git a/src/te/operation/create_primfunc.cc b/src/te/operation/create_primfunc.cc index 190892b2283f..a47556bac101 100644 --- a/src/te/operation/create_primfunc.cc +++ b/src/te/operation/create_primfunc.cc @@ -109,7 +109,7 @@ BlockRealize GenerateBlockFromTensor(const te::ComputeOp& compute_op, const te:: } // Step 2. Declare buffer and update op2buffers - Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint()); + Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, tensor->GetNameHint(), "global"); info->tensor2buffers[tensor] = buffer; // Step 3. Add Buffer to root_alloc @@ -270,7 +270,8 @@ PrimFunc CreatePrimFunc(const Array& arg_list) { const te::Tensor& tensor = op.output(0); // Check op is in op list ICHECK(info.IsArg(tensor)); - const Buffer& buffer = decl_buffer(placeholder->shape, placeholder->dtype, placeholder->name); + const Buffer& buffer = + decl_buffer(placeholder->shape, placeholder->dtype, placeholder->name, "global"); info.tensor2buffers[tensor] = buffer; } else if (const auto* compute_op = op.as()) { // Case 2. ComputeOp (te.compute) diff --git a/src/te/operation/cross_thread_reduction.cc b/src/te/operation/cross_thread_reduction.cc index da20dd875ba5..f844090ca6f5 100644 --- a/src/te/operation/cross_thread_reduction.cc +++ b/src/te/operation/cross_thread_reduction.cc @@ -146,7 +146,7 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, for (size_t i = 0; i < size; ++i) { DataType t = reduces[i]->dtype; normal_res_handles.emplace_back("normal_reduce_temp" + std::to_string(i), - PointerType(PrimType(t))); + PointerType(PrimType(t), "local")); lhs.push_back(Load(t, normal_res_handles[i], 0, const_true(t.lanes()))); } Array init_value = combiner->identity_element; @@ -177,7 +177,8 @@ Stmt MakeCrossThreadReduction(const ComputeOpNode* self, const Stage& stage, std::vector res_handles(size); for (size_t idx = 0; idx < size; ++idx) { DataType dtype = reduces[idx]->dtype; - res_handles[idx] = Var("reduce_temp" + std::to_string(idx), PointerType(PrimType(dtype))); + res_handles[idx] = + Var("reduce_temp" + std::to_string(idx), PointerType(PrimType(dtype), "local")); freduce_args.push_back(res_handles[idx]); } diff --git a/src/te/operation/extern_op.cc b/src/te/operation/extern_op.cc index 1c9a3cb336ae..b602efcfc28b 100644 --- a/src/te/operation/extern_op.cc +++ b/src/te/operation/extern_op.cc @@ -124,7 +124,7 @@ void ExternOpNode::GatherBound(const Operation& self, Stmt ExternOpNode::BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const { + const Stmt& body, String storage_scope) const { ICHECK_EQ(stage->op.get(), this); Stmt realize_body = body; for (int k = 0; k < num_outputs(); ++k) { @@ -133,7 +133,7 @@ Stmt ExternOpNode::BuildRealize(const Stage& stage, for (size_t i = 0; i < t->shape.size(); ++i) { bounds.push_back(Range::FromMinExtent(make_const(t->shape[i].dtype(), 0), t->shape[i])); } - realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body); + realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body, storage_scope); } return realize_body; } diff --git a/src/te/operation/hybrid_op.cc b/src/te/operation/hybrid_op.cc index 65b8660ca1fb..5d2412abb3d2 100644 --- a/src/te/operation/hybrid_op.cc +++ b/src/te/operation/hybrid_op.cc @@ -144,7 +144,7 @@ void HybridOpNode::GatherBound(const Operation& self, Stmt HybridOpNode::BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const { + const Stmt& body, String storage_scope) const { // TODO(@were): Add attribute inject here and remove it from hybrid parser. ICHECK_EQ(stage->op.get(), this); Stmt realize_body = body; @@ -154,7 +154,7 @@ Stmt HybridOpNode::BuildRealize(const Stage& stage, for (size_t i = 0; i < t->shape.size(); ++i) { bounds.push_back(Range::FromMinExtent(make_const(t->shape[i].dtype(), 0), t->shape[i])); } - realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body); + realize_body = tir::ProducerRealize(t, bounds, const_true(), realize_body, storage_scope); } return realize_body; } diff --git a/src/te/operation/placeholder_op.cc b/src/te/operation/placeholder_op.cc index c51e53e16cd1..4f5df7ad3024 100644 --- a/src/te/operation/placeholder_op.cc +++ b/src/te/operation/placeholder_op.cc @@ -85,7 +85,7 @@ void PlaceholderOpNode::GatherBound(const Operation& self, Stmt PlaceholderOpNode::BuildRealize(const Stage& stage, const std::unordered_map& realize_map, - const Stmt& body) const { + const Stmt& body, String storage_scope) const { return body; } diff --git a/src/te/operation/scan_op.cc b/src/te/operation/scan_op.cc index a555e86097b7..39689bd9654a 100644 --- a/src/te/operation/scan_op.cc +++ b/src/te/operation/scan_op.cc @@ -234,7 +234,7 @@ void ScanOpNode::GatherBound(const Operation& self, } Stmt ScanOpNode::BuildRealize(const Stage& stage, const std::unordered_map& dom_map, - const Stmt& body) const { + const Stmt& body, String storage_scope) const { arith::Analyzer analyzer; ICHECK_EQ(stage->op.get(), this); Range sdom = dom_map.at(this->scan_axis); @@ -250,7 +250,7 @@ Stmt ScanOpNode::BuildRealize(const Stage& stage, const std::unordered_mapspatial_axis_[sp_idx]; bounds.push_back(dom_map.at(sp_ax)); } - ret = tir::ProducerRealize(t, bounds, const_true(), ret); + ret = tir::ProducerRealize(t, bounds, const_true(), ret, storage_scope); } return ret; } diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index 355e3c39494b..825092d20ac0 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -51,11 +51,8 @@ Stmt MakePipeline(const Stage& s, const std::unordered_map& dom_ if (consumer.defined() && !is_no_op(consumer)) { pipeline = SeqStmt({producer, consumer}); } - pipeline = s->op->BuildRealize(s, dom_map, pipeline); - // use attribute to mark scope of the operation. - pipeline = AttrStmt(s->op, tir::attr::realize_scope, StringImm(s->scope), pipeline); - return pipeline; + return s->op->BuildRealize(s, dom_map, pipeline, s->scope); } // inject the operator's realization on the stmt. @@ -175,8 +172,7 @@ class SchedulePostProc : public StmtExprMutator { thread_extent_scope_.erase(op->node.get()); return ret; } - } else if (op->attr_key == tir::attr::realize_scope || - op->attr_key == tir::attr::double_buffer_scope) { + } else if (op->attr_key == tir::attr::double_buffer_scope) { auto it = replace_op_.find(op->node.get()); if (it != replace_op_.end()) { if (it->second.defined()) { @@ -218,7 +214,8 @@ class SchedulePostProc : public StmtExprMutator { auto it = replace_realize_.find(key); if (it != replace_realize_.end()) { if (it->second.defined()) { - Stmt ret = ProducerRealize(it->second, op->bounds, op->condition, op->body); + Stmt ret = + ProducerRealize(it->second, op->bounds, op->condition, op->body, op->storage_scope); return this->VisitStmt(ret); } else { return this->VisitStmt(op->body); diff --git a/src/te/schedule/schedule_postproc_to_primfunc.cc b/src/te/schedule/schedule_postproc_to_primfunc.cc index 5c59961fe011..2063fc7cad6a 100644 --- a/src/te/schedule/schedule_postproc_to_primfunc.cc +++ b/src/te/schedule/schedule_postproc_to_primfunc.cc @@ -49,12 +49,12 @@ namespace tvm { namespace te { // create a buffer for tensor. -Buffer CreateBufferFor(const Tensor& tensor) { +Buffer CreateBufferFor(const Tensor& tensor, String storage_scope = "") { std::string name = tensor->op->name; if (tensor->op->num_outputs() != 1) { name += ".v" + std::to_string(tensor->value_index); } - Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, name); + Buffer buffer = decl_buffer(tensor->shape, tensor->dtype, name, storage_scope); return buffer; } @@ -67,10 +67,7 @@ class TensorToBufferMapper : public StmtExprMutator { Stmt VisitStmt_(const AttrStmtNode* op) final { auto ret = StmtExprMutator::VisitStmt_(op); op = ret.as(); - // TODO(tvm-team): remove realize_scope, turn the info into - // Buffer's scope field in this pass. - if (op->attr_key == tir::attr::realize_scope || - op->attr_key == tir::attr::double_buffer_scope) { + if (op->attr_key == tir::attr::double_buffer_scope) { Stmt body = op->body; Operation operation = Downcast(op->node); for (int i = operation->num_outputs(); i != 0; --i) { @@ -95,7 +92,7 @@ class TensorToBufferMapper : public StmtExprMutator { Stmt VisitStmt_(const ProducerRealizeNode* op) final { Tensor tensor = Downcast(op->producer); - Buffer buffer = GetOrAllocBuffer(tensor); + Buffer buffer = GetOrAllocBuffer(tensor, op->storage_scope); auto ret = StmtExprMutator::VisitStmt_(op); op = ret.as(); @@ -122,14 +119,16 @@ class TensorToBufferMapper : public StmtExprMutator { } private: - Buffer GetOrAllocBuffer(const Tensor& tensor) { return GetBuffer(tensor, true); } + Buffer GetOrAllocBuffer(const Tensor& tensor, String storage_scope = "") { + return GetBuffer(tensor, storage_scope, true); + } - Buffer GetBuffer(const Tensor& tensor, bool allow_alloc = false) { + Buffer GetBuffer(const Tensor& tensor, String storage_scope = "", bool allow_alloc = false) { auto it = buffer_map_.find(tensor); if (it != buffer_map_.end()) return it->second; ICHECK(allow_alloc) << "Cannot find the Realization point of tensor " << tensor; - auto buffer = CreateBufferFor(tensor); + auto buffer = CreateBufferFor(tensor, storage_scope); buffer_map_[tensor] = buffer; return buffer; } diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc index 1667eb7d1fbd..e2fcf89d8966 100644 --- a/src/tir/ir/buffer.cc +++ b/src/tir/ir/buffer.cc @@ -45,9 +45,10 @@ Array SimplifyArray(arith::Analyzer* ana, Array array) { return array; } -Buffer decl_buffer(Array shape, DataType dtype, String name, Span span) { +Buffer decl_buffer(Array shape, DataType dtype, String name, String storage_scope, + Span span) { DataType storage_dtype = (dtype == DataType::Bool() ? DataType::Int(8) : dtype); - return Buffer(Var(name, PointerType(PrimType(storage_dtype)), span), dtype, shape, + return Buffer(Var(name, PointerType(PrimType(storage_dtype), storage_scope), span), dtype, shape, Array(), PrimExpr(), name, "", 0, 0, kDefault, span); } diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index b2016eb74c91..42ef60bb86d7 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -61,6 +61,16 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // AttrStmt AttrStmt::AttrStmt(ObjectRef node, String attr_key, PrimExpr value, Stmt body, Span span) { + if (attr_key == attr::storage_scope) { + const VarNode* buf = node.as(); + ICHECK(buf); + const auto* ptr_type = buf->type_annotation.as(); + ICHECK(ptr_type) << "The provided variable is not of pointer type"; + auto attr_scope = value.as()->value; + ICHECK_EQ(attr_scope, ptr_type->storage_scope) + << "Storage scopes attached to AttrStmt and buffer var are different. " << attr_scope + << ", " << ptr_type->storage_scope; + } auto n = make_object(); n->node = node; n->attr_key = std::move(attr_key); @@ -377,7 +387,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) // ProducerRealize ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, - Stmt body, Span span) { + Stmt body, String storage_scope, Span span) { for (size_t i = 0; i < bounds.size(); ++i) { ICHECK(bounds[i]->min.defined()); ICHECK(bounds[i]->extent.defined()); @@ -394,13 +404,14 @@ ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr node->condition = std::move(condition); node->body = std::move(body); node->span = std::move(span); + node->storage_scope = std::move(storage_scope); data_ = std::move(node); } TVM_REGISTER_GLOBAL("tir.ProducerRealize") .set_body_typed([](DataProducer producer, Region bounds, PrimExpr condition, Stmt body, - Span span) { - return ProducerRealize(producer, bounds, condition, body, span); + String storage_scope, Span span) { + return ProducerRealize(producer, bounds, condition, body, storage_scope, span); }); TVM_REGISTER_NODE_TYPE(ProducerRealizeNode); diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index cbae3f95ec68..f7ece25d3fcd 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -201,5 +201,11 @@ class IRConvertSSA final : public StmtExprMutator { Stmt ConvertSSA(Stmt stmt) { return IRConvertSSA()(std::move(stmt)); } +String GetPtrStorageScope(Var buffer_var) { + const auto* ptr_type = buffer_var->type_annotation.as(); + ICHECK(ptr_type) << "The provided variable is not of pointer type"; + return ptr_type->storage_scope; +} + } // namespace tir } // namespace tvm diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h index 906ff8a38b6c..b5a154b707af 100644 --- a/src/tir/transforms/ir_utils.h +++ b/src/tir/transforms/ir_utils.h @@ -191,6 +191,12 @@ inline PrimExpr StackAlloca(std::string type, size_t num) { */ Stmt ConvertSSA(Stmt stmt); +/*! + * \brief Return the storage scope associated with a buffer variable. + * \param buffer_var The input buffer variable. + * \return A string representing the storage scope of this buffer variable. + */ +String GetPtrStorageScope(Var buffer_var); } // namespace tir } // namespace tvm #endif // TVM_TIR_TRANSFORMS_IR_UTILS_H_ diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc index 9e536814fa12..25a2f4e060dd 100644 --- a/src/tir/transforms/lower_thread_allreduce.cc +++ b/src/tir/transforms/lower_thread_allreduce.cc @@ -33,10 +33,33 @@ #include "../../runtime/thread_storage_scope.h" #include "ir_utils.h" +#include "update_pointer_storage_scope.h" namespace tvm { namespace tir { +class UpdatePointerStorageScopeAllReduce final : public UpdatePointerStorageScope { + public: + explicit UpdatePointerStorageScopeAllReduce( + const std::unordered_map& new_storage_scopes) + : UpdatePointerStorageScope(new_storage_scopes) {} + + Stmt VisitStmt_(const AllocateNode* op) final { + auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); + auto new_scope = GetPtrStorageScope(remapped); + if (new_scope != GetPtrStorageScope(op->buffer_var)) { + Stmt body = StmtExprMutator::VisitStmt(op->body); + if (new_scope == "shared") { + // use volatile access to shared buffer. + body = AttrStmt(remapped, attr::volatile_scope, 1, body); + } + body = Allocate(remapped, op->dtype, op->extents, op->condition, body); + return AttrStmt(remapped, attr::storage_scope, StringImm(new_scope), body); + } + return StmtExprMutator::VisitStmt_(op); + } +}; + class ThreadAllreduceBuilder final : public StmtExprMutator { public: explicit ThreadAllreduceBuilder(const TargetNode* target) @@ -86,12 +109,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { const AllocateNode* repl = it->second.as(); if (warp_allocs_.count(repl)) { stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body); - stmt = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("local"), stmt); + new_storage_scopes_[repl->buffer_var.get()] = "local"; } else { - // use volatile access to shared buffer. - stmt = AttrStmt(repl->buffer_var, attr::volatile_scope, 1, op->body); - stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, stmt); - stmt = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("shared"), stmt); + stmt = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body); + new_storage_scopes_[repl->buffer_var.get()] = "shared"; } return stmt; } else { @@ -108,6 +129,8 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { } } + std::unordered_map new_storage_scopes_; + private: // Thread entry struct ThreadEntry { @@ -366,7 +389,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator { const AllocateNode* repl = var.as(); if (repl) { body = Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, body); - body = AttrStmt(repl->buffer_var, attr::storage_scope, StringImm("local"), body); + new_storage_scopes_[repl->buffer_var.get()] = "local"; } } @@ -590,7 +613,10 @@ Pass LowerThreadAllreduce() { auto target = f->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "LowerThreadAllreduce: Require the target attribute"; const TargetNode* target_node = target.as(); - n->body = ThreadAllreduceBuilder(target_node)(n->body); + ThreadAllreduceBuilder thread_all_reduce(target_node); + auto reduce_body = thread_all_reduce(n->body); + n->body = + UpdatePointerStorageScopeAllReduce(thread_all_reduce.new_storage_scopes_)(reduce_body); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerThreadAllreduce", {}); diff --git a/src/tir/transforms/lower_warp_memory.cc b/src/tir/transforms/lower_warp_memory.cc index b95681a936ca..060b02c3d137 100644 --- a/src/tir/transforms/lower_warp_memory.cc +++ b/src/tir/transforms/lower_warp_memory.cc @@ -40,6 +40,8 @@ #include "../../arith/pattern_match.h" #include "../../runtime/thread_storage_scope.h" +#include "ir_utils.h" +#include "update_pointer_storage_scope.h" namespace tvm { namespace tir { @@ -356,6 +358,8 @@ class WarpMemoryRewriter : private StmtMutator { return stmt; } + std::unordered_map new_storage_scopes_; + private: Stmt VisitStmt_(const AllocateNode* op) { auto ret = StmtMutator::VisitStmt_(op); @@ -374,9 +378,7 @@ class WarpMemoryRewriter : private StmtMutator { StorageScope scope = StorageScope::Create(op->value.as()->value); if (scope.rank == runtime::StorageRank::kWarp) { warp_buffer_.insert(buf); - Stmt ret = StmtMutator::VisitStmt_(op); - op = ret.as(); - return AttrStmt(op->node, op->attr_key, StringImm("local"), op->body); + new_storage_scopes_[buf] = "local"; } } return StmtMutator::VisitStmt_(op); @@ -397,7 +399,9 @@ Pass LowerWarpMemory() { auto target = f->GetAttr(tvm::attr::kTarget); ICHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; int warp_size = target.value()->GetAttr("thread_warp_size", 1).value(); - n->body = WarpMemoryRewriter(warp_size).Rewrite(std::move(n->body)); + WarpMemoryRewriter warp_memory_rewriter(warp_size); + auto stmt = warp_memory_rewriter.Rewrite(std::move(n->body)); + n->body = UpdatePointerStorageScope(warp_memory_rewriter.new_storage_scopes_)(stmt); return f; }; return CreatePrimFuncPass(pass_func, 0, "tir.LowerWarpMemory", {}); diff --git a/src/tir/transforms/storage_access.cc b/src/tir/transforms/storage_access.cc index 00002d3587db..9dae0006facd 100644 --- a/src/tir/transforms/storage_access.cc +++ b/src/tir/transforms/storage_access.cc @@ -35,7 +35,7 @@ namespace tir { void StorageAccessVisitor::VisitExpr_(const LoadNode* op) { const VarNode* buf = op->buffer_var.as(); - StorageScope scope = GetScope(buf); + StorageScope scope = GetScope(op->buffer_var); if (Enabled(buf, scope)) { ICHECK(allow_append_) << op << " " << scope.to_string(); AccessEntry e; @@ -56,7 +56,7 @@ void StorageAccessVisitor::VisitStmt_(const StoreNode* op) { ICHECK_EQ(curr_stmt_.access.size(), 0U); curr_stmt_.stmt = op; const VarNode* buf = op->buffer_var.as(); - StorageScope scope = GetScope(buf); + StorageScope scope = GetScope(op->buffer_var); if (Enabled(buf, scope)) { AccessEntry e; e.threads = env_threads(); @@ -90,11 +90,7 @@ void StorageAccessVisitor::VisitStmt_(const EvaluateNode* op) { } void StorageAccessVisitor::VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - storage_scope_[buf] = StorageScope::Create(op->value.as()->value); - StmtExprVisitor::VisitStmt_(op); - } else if (op->attr_key == attr::double_buffer_write) { + if (op->attr_key == attr::double_buffer_write) { ICHECK(double_buffer_write_ == nullptr); double_buffer_write_ = op->node.as(); scope_.push_back(std::vector()); @@ -208,7 +204,7 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { PrimExpr offset = op->args[2]; PrimExpr extent = op->args[3]; const IntImmNode* flag = op->args[4].as(); - StorageScope scope = GetScope(buffer); + StorageScope scope = GetScope(GetRef(buffer)); // The buffer scope. if (Enabled(buffer, scope)) { ICHECK(allow_append_); @@ -244,12 +240,11 @@ void StorageAccessVisitor::VisitExpr_(const CallNode* op) { } } -StorageScope StorageAccessVisitor::GetScope(const VarNode* buf) const { - auto it = storage_scope_.find(buf); - StorageScope s; - s.rank = StorageRank::kGlobal; - if (it == storage_scope_.end()) return s; - return it->second; +StorageScope StorageAccessVisitor::GetScope(Var buffer_var) const { + if (buffer_var->type_annotation.as()) { + return StorageScope::Create(GetPtrStorageScope(buffer_var)); + } + return StorageScope(); // global by default } } // namespace tir diff --git a/src/tir/transforms/storage_access.h b/src/tir/transforms/storage_access.h index 663c570fd15c..9dc4c923b054 100644 --- a/src/tir/transforms/storage_access.h +++ b/src/tir/transforms/storage_access.h @@ -118,7 +118,7 @@ class StorageAccessVisitor : public StmtExprVisitor { * \brief Get the scope of the buffer array. * \return The scope of the final buffer array. */ - StorageScope GetScope(const VarNode* buf) const; + StorageScope GetScope(Var buffer_var) const; // access scope std::vector > scope_; @@ -135,8 +135,6 @@ class StorageAccessVisitor : public StmtExprVisitor { StmtEntry curr_stmt_; // The involving threads Array env_threads_; - // The storage scope of each buffer - std::unordered_map storage_scope_; }; } // namespace tir diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 43fc1f1ec53f..0db86130a8da 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -78,11 +78,7 @@ class StorageFlattener : public StmtExprMutator { } Stmt VisitStmt_(const AttrStmtNode* op) final { - if (op->attr_key == attr::realize_scope) { - storage_scope_[op->node.get()] = op->value.as()->value; - return this->VisitStmt(op->body); - } else if (op->attr_key == attr::double_buffer_scope && - op->node->IsInstance()) { + if (op->attr_key == attr::double_buffer_scope && op->node->IsInstance()) { auto buffer = Downcast(op->node); Stmt body = this->VisitStmt(op->body); auto it = buf_map_.find(buffer); @@ -156,10 +152,8 @@ class StorageFlattener : public StmtExprMutator { shape.push_back(r->extent); } // deduce current storage scope. - auto it = storage_scope_.find(op->buffer.get()); - ICHECK(it != storage_scope_.end()) << "Cannot find storage scope of " << op->buffer; StorageScope skey; - const std::string& strkey = it->second; + std::string strkey = GetPtrStorageScope(op->buffer->data); if (strkey.length() == 0) { if (curr_thread_scope_.size() != 0) { skey.rank = runtime::DefaultStorageRank(curr_thread_scope_.back().rank); @@ -167,7 +161,6 @@ class StorageFlattener : public StmtExprMutator { } else { skey = StorageScope::Create(strkey); } - // use small alignment for small arrays auto dtype = op->buffer->dtype; int32_t const_size = AllocateNode::constant_allocation_size(shape); @@ -200,8 +193,11 @@ class StorageFlattener : public StmtExprMutator { strides = Array(rstrides.rbegin(), rstrides.rend()); } - e.buffer = Buffer(Var(op->buffer->data->name_hint, op->buffer->data->type_annotation), - op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name, + auto* ptr_type = op->buffer->data->type_annotation.as(); + ICHECK(ptr_type); + auto new_var = + Var(op->buffer->data->name_hint, PointerType(ptr_type->element_type, skey.to_string())); + e.buffer = Buffer(new_var, op->buffer->dtype, shape, strides, PrimExpr(), op->buffer->name, skey.to_string(), align, 0, kDefault); buf_map_[key] = e; @@ -491,8 +487,6 @@ class StorageFlattener : public StmtExprMutator { std::unordered_map buf_map_; // Dimension alignment std::unordered_map, ObjectPtrHash, ObjectPtrEqual> dim_align_; - // Storage scope - std::unordered_map storage_scope_; // The current thread scope. std::vector curr_thread_scope_; // Collects shapes. diff --git a/src/tir/transforms/storage_rewrite.cc b/src/tir/transforms/storage_rewrite.cc index c755576e2b88..613d02614b39 100644 --- a/src/tir/transforms/storage_rewrite.cc +++ b/src/tir/transforms/storage_rewrite.cc @@ -75,8 +75,6 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { }; // The scope of each allocation struct AllocEntry { - // Scope used for allocation. - StorageScope storage_scope; // scope level size_t level{0}; // allocation stmt @@ -86,13 +84,8 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { void VisitStmt_(const AllocateNode* op) final { size_t level = scope_.size(); const VarNode* buf = op->buffer_var.get(); - auto it = alloc_info_.find(buf); - ICHECK(it != alloc_info_.end()) << "Could not find buffer `" << buf->name_hint - << "` in the list of allocated buffers. Perhaps you are " - "missing a storage_scope attr for this buffer."; - ICHECK(it->second.alloc == nullptr); - it->second.alloc = op; - it->second.level = level; + alloc_info_[buf].alloc = op; + alloc_info_[buf].level = level; StmtExprVisitor::VisitStmt_(op); } void VisitStmt_(const StoreNode* op) final { @@ -180,10 +173,6 @@ class LinearAccessPatternFinder final : public StmtExprVisitor { VisitNewScope(op); } else if (op->attr_key == attr::virtual_thread) { VisitNewScope(op); - } else if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - alloc_info_[buf].storage_scope = StorageScope::Create(op->value.as()->value); - StmtExprVisitor::VisitStmt_(op); } else { StmtExprVisitor::VisitStmt_(op); } @@ -716,7 +705,8 @@ class StoragePlanRewriter : public StmtExprMutator { for (const VarNode* var : it->second.gen) { ICHECK(alloc_info.count(var)); - const AllocEntry& ae = alloc_info.at(var); + const AllocateNode* alloc = alloc_info.at(var).alloc; + auto storage_scope = StorageScope::Create(GetPtrStorageScope(GetRef(var))); StorageEntry* dst_entry = nullptr; // inplace detection if (detect_inplace) { @@ -726,13 +716,12 @@ class StoragePlanRewriter : public StmtExprMutator { if (!inplace_flag.count(src) && alloc_map_.count(src)) { InplaceOpVerifier visitor; StorageEntry* src_entry = alloc_map_.at(src); - if (src_entry->scope == ae.storage_scope && + if (src_entry->scope == storage_scope && src_entry->attach_scope_ == thread_scope_ && - src_entry->elem_type == ae.alloc->dtype.element_of() && + src_entry->elem_type == alloc->dtype.element_of() && visitor.Check(s.stmt, var, src)) { - uint64_t const_nbits = - static_cast(ae.alloc->constant_allocation_size()) * - ae.alloc->dtype.bits() * ae.alloc->dtype.lanes(); + uint64_t const_nbits = static_cast(alloc->constant_allocation_size()) * + alloc->dtype.bits() * alloc->dtype.lanes(); if (src_entry->const_nbits == const_nbits && !inplace_found) { // successfully inplace dst_entry = src_entry; @@ -744,9 +733,9 @@ class StoragePlanRewriter : public StmtExprMutator { } } if (dst_entry == nullptr) { - dst_entry = FindAlloc(ae.alloc, thread_scope_, ae.storage_scope); + dst_entry = FindAlloc(alloc, thread_scope_, storage_scope); } - dst_entry->allocs.emplace_back(ae.alloc); + dst_entry->allocs.emplace_back(alloc); alloc_map_[var] = dst_entry; } } @@ -933,7 +922,8 @@ class VectorAllocRewriter : public StmtExprMutator { extents[extents.size() - 1] / make_const(extents[0].dtype(), factor)); // create a new buffer var DataType new_dtype = tvec[0]; - Var new_buffer_var(op->buffer_var->name_hint, PointerType(PrimType(new_dtype))); + Var new_buffer_var(op->buffer_var->name_hint, + PointerType(PrimType(new_dtype), GetPtrStorageScope(op->buffer_var))); // update the remap req. var_remap_.Set(op->buffer_var, new_buffer_var); return Allocate(new_buffer_var, new_dtype, extents, op->condition, op->body); diff --git a/src/tir/transforms/thread_storage_sync.cc b/src/tir/transforms/thread_storage_sync.cc index 8f757171afbd..35e4563b8f58 100644 --- a/src/tir/transforms/thread_storage_sync.cc +++ b/src/tir/transforms/thread_storage_sync.cc @@ -223,14 +223,14 @@ class ThreadSyncInserter : public StmtExprMutator { } PrimExpr VisitExpr_(const LoadNode* op) final { if (sync_scope_.rank == StorageRank::kGlobal && - GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) { + GetScope(op->buffer_var).rank == StorageRank::kGlobal) { ++rw_stats_[op->buffer_var].read_count; } return StmtExprMutator::VisitExpr_(op); } Stmt VisitStmt_(const StoreNode* op) final { if (sync_scope_.rank == StorageRank::kGlobal && - GetScope(op->buffer_var.get()).rank == StorageRank::kGlobal) { + GetScope(op->buffer_var).rank == StorageRank::kGlobal) { ++rw_stats_[op->buffer_var].write_count; } return StmtExprMutator::VisitStmt_(op); @@ -250,10 +250,6 @@ class ThreadSyncInserter : public StmtExprMutator { is_lead_ = PrimExpr(); } return ret; - } else if (op->attr_key == attr::storage_scope) { - const VarNode* buf = op->node.as(); - storage_scope_[buf] = StorageScope::Create(op->value.as()->value); - return StmtExprMutator::VisitStmt_(op); } else { return StmtExprMutator::VisitStmt_(op); } @@ -264,16 +260,15 @@ class ThreadSyncInserter : public StmtExprMutator { PrimExpr expr = StmtExprMutator::VisitExpr_(op); op = expr.as(); ICHECK_EQ(op->args.size(), 5U); - const VarNode* buffer_var = op->args[1].as(); - Var var(GetRef(buffer_var)); + Var buffer_var(GetRef(op->args[1].as())); const IntImmNode* flag = op->args[4].as(); if ((flag->value & 1) && sync_scope_.rank == StorageRank::kGlobal && GetScope(buffer_var).rank == StorageRank::kGlobal) { - ++rw_stats_[var].read_count; + ++rw_stats_[buffer_var].read_count; } if (flag->value & 2 && sync_scope_.rank == StorageRank::kGlobal && GetScope(buffer_var).rank == StorageRank::kGlobal) { - ++rw_stats_[var].write_count; + ++rw_stats_[buffer_var].write_count; } return expr; } else { @@ -287,14 +282,12 @@ class ThreadSyncInserter : public StmtExprMutator { int read_count{0}; int write_count{0}; }; + // Get current storage scope. - StorageScope GetScope(const VarNode* buf) const { - auto it = storage_scope_.find(buf); - StorageScope s; - s.rank = StorageRank::kGlobal; - if (it == storage_scope_.end()) return s; - return it->second; + StorageScope GetScope(Var buffer_var) const { + return StorageScope::Create(GetPtrStorageScope(buffer_var)); } + // private functions. Stmt InitGlobalBarrier(const AttrStmtNode* op) { ICHECK(op != nullptr); @@ -337,8 +330,6 @@ class ThreadSyncInserter : public StmtExprMutator { // data structure. StorageScope sync_scope_; const std::unordered_set& syncs_; - // The storage scope of each buffer - std::unordered_map storage_scope_; // The read write statistics of storage std::unordered_map rw_stats_; // The statistics for global barrier diff --git a/src/tir/transforms/update_pointer_storage_scope.cc b/src/tir/transforms/update_pointer_storage_scope.cc new file mode 100644 index 000000000000..0ae02fec9f95 --- /dev/null +++ b/src/tir/transforms/update_pointer_storage_scope.cc @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file update_pointer_storage_scope.cc + * \brief A pass to update storage scopes for buffer variables. + */ +#include "update_pointer_storage_scope.h" + +#include +#include +#include +#include + +#include + +#include "../../runtime/thread_storage_scope.h" +#include "ir_utils.h" + +namespace tvm { +namespace tir { + +Var WithStorageScope(const VarNode* buffer_var, String storage_scope) { + auto* ptr_type = buffer_var->type_annotation.as(); + ICHECK(ptr_type) << "The provided variable is not of pointer type"; + return Var(buffer_var->name_hint, PointerType(ptr_type->element_type, storage_scope), + buffer_var->span); +} + +UpdatePointerStorageScope::UpdatePointerStorageScope( + const std::unordered_map& new_storage_scopes) { + for (auto& kv : new_storage_scopes) { + new_var_remap_[kv.first] = WithStorageScope(kv.first, kv.second); + } +} + +PrimExpr UpdatePointerStorageScope::VisitExpr_(const VarNode* op) { + auto it = new_var_remap_.find(op); + if (it == new_var_remap_.end()) { + return GetRef(op); + } + return it->second; +} + +PrimExpr UpdatePointerStorageScope::VisitExpr_(const LoadNode* op) { + auto remapped = StmtExprMutator::VisitExpr(op->buffer_var); + return Load(op->dtype, Downcast(remapped), StmtExprMutator::VisitExpr(op->index), + StmtExprMutator::VisitExpr(op->predicate)); +} + +Stmt UpdatePointerStorageScope::VisitStmt_(const AttrStmtNode* op) { + if (op->attr_key == attr::storage_scope) { + const VarNode* buf = op->node.as(); + auto remapped = Downcast(StmtExprMutator::VisitExpr(GetRef(buf))); + auto new_scope = GetPtrStorageScope(remapped); + return AttrStmt(remapped, attr::storage_scope, StringImm(new_scope), + StmtMutator::VisitStmt(op->body)); + } + return StmtMutator::VisitStmt_(op); +} + +Stmt UpdatePointerStorageScope::VisitStmt_(const AllocateNode* op) { + auto remapped = Downcast(StmtExprMutator::VisitExpr(op->buffer_var)); + return Allocate(remapped, op->dtype, op->extents, StmtExprMutator::VisitExpr(op->condition), + StmtExprMutator::VisitStmt(op->body)); +} + +Stmt UpdatePointerStorageScope::VisitStmt_(const StoreNode* op) { + auto remapped = StmtExprMutator::VisitExpr(op->buffer_var); + return Store(Downcast(remapped), StmtExprMutator::VisitExpr(op->value), + StmtExprMutator::VisitExpr(op->index), StmtExprMutator::VisitExpr(op->predicate)); +} + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/update_pointer_storage_scope.h b/src/tir/transforms/update_pointer_storage_scope.h new file mode 100644 index 000000000000..481536a45b27 --- /dev/null +++ b/src/tir/transforms/update_pointer_storage_scope.h @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file update_pointer_storage_scope.h + * \brief A pass to update storage scopes for buffer variables. + */ +#ifndef TVM_TIR_TRANSFORMS_UPDATE_POINTER_STORAGE_SCOPE_H_ +#define TVM_TIR_TRANSFORMS_UPDATE_POINTER_STORAGE_SCOPE_H_ + +#include +#include +#include + +#include + +namespace tvm { +namespace tir { + +class UpdatePointerStorageScope : public StmtExprMutator { + public: + explicit UpdatePointerStorageScope( + const std::unordered_map& new_storage_scopes); + + virtual PrimExpr VisitExpr_(const VarNode*); + virtual PrimExpr VisitExpr_(const LoadNode*); + virtual Stmt VisitStmt_(const AttrStmtNode*); + virtual Stmt VisitStmt_(const AllocateNode*); + virtual Stmt VisitStmt_(const StoreNode*); + + private: + std::unordered_map new_var_remap_; +}; + +} // namespace tir +} // namespace tvm +#endif // TVM_TIR_TRANSFORMS_UPDATE_POINTER_STORAGE_SCOPE_H_ diff --git a/tests/python/unittest/test_te_hybrid_script.py b/tests/python/unittest/test_te_hybrid_script.py index 30b96546f991..e9626e7f31b4 100644 --- a/tests/python/unittest/test_te_hybrid_script.py +++ b/tests/python/unittest/test_te_hybrid_script.py @@ -189,9 +189,7 @@ def fanout(n, a): assert ir.min.value == 0 assert tvm.ir.structural_equal(ir.extent, n - 3) # Check loopbody - ibody = ir.body - assert isinstance(ibody, tvm.tir.AttrStmt) - abody = ibody.body + abody = ir.body assert isinstance(abody, tvm.tir.ProducerRealize) assert abody.bounds[0].min.value == 0 assert abody.bounds[0].extent.value == 1 diff --git a/tests/python/unittest/test_te_schedule_tensorize.py b/tests/python/unittest/test_te_schedule_tensorize.py index e2c2f7f7e0e5..ae5e7051bfba 100644 --- a/tests/python/unittest/test_te_schedule_tensorize.py +++ b/tests/python/unittest/test_te_schedule_tensorize.py @@ -379,8 +379,8 @@ def intrin_func(ins, outs): stmt = tvm.te.schedule.ScheduleOps(s, dom_map) # The loop that we tried to tensorize still exists in the code # That means tensorize didn't work as expected - assert isinstance(stmt.body.body, tvm.tir.For) - assert stmt.body.body.loop_var.name == C.op.axis[0].var.name + assert isinstance(stmt.body, tvm.tir.For) + assert stmt.body.loop_var.name == C.op.axis[0].var.name if __name__ == "__main__": diff --git a/tests/python/unittest/test_te_tensor.py b/tests/python/unittest/test_te_tensor.py index ed4a21397885..2931925965b7 100644 --- a/tests/python/unittest/test_te_tensor.py +++ b/tests/python/unittest/test_te_tensor.py @@ -309,7 +309,7 @@ def get_B1_realize(x): ret = [] tvm.tir.stmt_functor.post_order_visit(stmt, get_B1_realize) - assert stmt.node == C.op and len(ret) == 1 + assert stmt.producer == C and len(ret) == 1 def test_tensor_inputs(): diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index c997748649cd..6929a329ac0f 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -35,7 +35,7 @@ def compacted_elementwise_func(a: ty.handle, c: ty.handle) -> None: with tir.block([]): tir.reads(A[i, 0:16]) tir.writes(C[i, 0:16]) - B = tir.alloc_buffer([1, 16], "float32") + B = tir.alloc_buffer([1, 16], "float32", scope="global") for j in range(0, 16): with tir.block() as []: tir.reads(A[i, j]) @@ -111,7 +111,7 @@ def compacted_symbolic_func(a: ty.handle, c: ty.handle, n: ty.int32, m: ty.int32 with tir.block([]): tir.reads(A[i, m]) tir.writes(C[i, m]) - B = tir.alloc_buffer((m,), "float32") + B = tir.alloc_buffer((m,), "float32", scope="global") for j in range(0, m): with tir.block([]) as []: tir.reads(A[i, j]) @@ -190,8 +190,8 @@ def compacted_multi_alloc_func(a: ty.handle, d: ty.handle) -> None: with tir.block([]) as []: tir.reads(A[i]) tir.writes(D[i]) - B = tir.alloc_buffer((32,)) - C = tir.alloc_buffer((32,)) + B = tir.alloc_buffer((32,), scope="global") + C = tir.alloc_buffer((32,), scope="global") B[i] = A[i] + 1.0 C[i] = A[i] + B[i] D[i] = C[i] * 2.0 diff --git a/tests/python/unittest/test_tir_transform_hoist_if.py b/tests/python/unittest/test_tir_transform_hoist_if.py index 252a187dbdc5..b111e2be75c7 100644 --- a/tests/python/unittest/test_tir_transform_hoist_if.py +++ b/tests/python/unittest/test_tir_transform_hoist_if.py @@ -636,7 +636,7 @@ def test_hoisting_block_scope_4(): def test_hoisting_block_scope_5(): ib = tvm.tir.ir_builder.create() - data = ib.pointer("float32", name="data") + data = ib.pointer("float32", name="data", scope="global") l = te.var("l") m = te.var("m") n = te.var("n") diff --git a/tests/python/unittest/test_tir_transform_loop_partition.py b/tests/python/unittest/test_tir_transform_loop_partition.py index 9e8848083908..6194024748e0 100644 --- a/tests/python/unittest/test_tir_transform_loop_partition.py +++ b/tests/python/unittest/test_tir_transform_loop_partition.py @@ -40,7 +40,7 @@ def test_basic(): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], stmt)) mod = tvm.tir.transform.LoopPartition()(mod) - stmt = tvm.tir.transform.Simplify()(mod)["main"].body + stmt = tvm.tir.transform.Simplify()(mod)["main"] assert not any(collect_visit(stmt.body.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))) assert any(collect_visit(stmt.body.body[1], lambda x: isinstance(x, tvm.tir.IfThenElse))) @@ -156,7 +156,7 @@ def test_thread_axis(): mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([], stmt)) mod = tvm.tir.transform.LoopPartition()(mod) - stmt = tvm.tir.transform.Simplify()(mod)["main"].body + stmt = tvm.tir.transform.Simplify()(mod)["main"] assert not any(collect_visit(stmt.body.body[0], lambda x: isinstance(x, tvm.tir.IfThenElse))) diff --git a/tests/python/unittest/test_tir_transform_lower_warp_memory.py b/tests/python/unittest/test_tir_transform_lower_warp_memory.py index ef474c15cfbb..f3baff120cf6 100644 --- a/tests/python/unittest/test_tir_transform_lower_warp_memory.py +++ b/tests/python/unittest/test_tir_transform_lower_warp_memory.py @@ -72,8 +72,8 @@ def test_lower_warp_memory_correct_indices(): bounds = tvm.te.schedule.InferBound(s) ir = tvm.te.schedule.ScheduleOps(s, bounds) - inner_func = ir.body.body.body.body - store_A_warp = inner_func.body.seq[0].body.body + inner_func = ir.body.body.body + store_A_warp = inner_func.seq[0].body.body indices = list(store_A_warp.indices) # A.warp is actually many buffers, one for each warp, although they are all called A.warp diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 164949552859..6c0e228e8e4c 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -277,8 +277,8 @@ def mmult( } ) # var definition - C_global = tir.var("handle") - packedB = tir.var("handle") + C_global = tir.buffer_var("float32", "global") + packedB = tir.buffer_var("float32", "global") # body assert num_args == 3, "mmult: num_args should be 3" arg0: ty.handle = tir.tvm_struct_get(args, 0, 12, dtype="handle")