diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index fe0d9ed44ae6..5bc492fcefb8 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -178,42 +178,6 @@ class AssertStmtNode : public StmtNode { TVM_DECLARE_FINAL_OBJECT_INFO(AssertStmtNode, StmtNode); }; -// TODO(tvm-team): consider consolidate with AttrStmt. -/*! \brief annotation node of producer/consumer relation. */ -class ProducerConsumerNode : public StmtNode { - public: - /*! \brief The corresponding tensor. */ - FunctionRef func; - /*! \brief Whether the relation is producer. */ - bool is_producer; - /*! \brief Body to be executed. */ - Stmt body; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("func", &func); - v->Visit("is_producer", &is_producer); - v->Visit("body", &body); - } - - bool SEqualReduce(const ProducerConsumerNode* other, SEqualReducer equal) const { - return - equal(func, other->func) && - equal(is_producer, other->is_producer) && - equal(body, other->body); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(func); - hash_reduce(is_producer); - hash_reduce(body); - } - - TVM_DLL static Stmt make(FunctionRef func, bool is_producer, Stmt body); - - static constexpr const char* _type_key = "ProducerConsumer"; - TVM_DECLARE_FINAL_OBJECT_INFO(ProducerConsumerNode, StmtNode); -}; - /*! * \brief Store value to the buffer. * @@ -385,10 +349,6 @@ class AllocateNode : public StmtNode { PrimExpr condition; /*! \brief The body to be executed. */ Stmt body; - // The following two fields are deprecated // kept for backward compatibility and will be refactored later. - PrimExpr new_expr; - std::string free_function; void VisitAttrs(AttrVisitor* v) { v->Visit("buffer_var", &buffer_var); @@ -419,9 +379,7 @@ class AllocateNode : public StmtNode { DataType dtype, Array<PrimExpr> extents, PrimExpr condition, - Stmt body, - PrimExpr new_expr = PrimExpr(), - std::string free_function = std::string()); + Stmt body); /*! * \brief If the buffer size is constant, return the size. @@ -589,8 +547,6 @@ class SeqStmt : public Stmt { * * - When an argument is nullptr, it will be ignored. * - When an argument is an array or a SeqStmt, it will be flattened recursively. - * - When an argument is a consumer block in ProducerConsumer, the consumer * tag will be dropped as such information is not useful in lowering. * - A normal Stmt will be appended to the end of the sequence. * * \note This function can directly return an element @@ -618,13 +574,6 @@ class SeqStmt : public Stmt { if (!stmt.defined()) return; if (auto* op = stmt.as<SeqStmtNode>()) { operator()(0, op->seq); - } else if (auto* op = stmt.as<ProducerConsumerNode>()) { - // NOTE: The consumer block annotation was not as useful and can be safely dropped. - if (!op->is_producer) { - operator()(0, op->body); - } else { - seq_->push_back(stmt); - } } else { seq_->push_back(stmt); } diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index ad5c5cd60d31..f93e9080a377 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -94,7 +94,6 @@ class StmtFunctor<R(const Stmt& n, Args... args)> { virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const FreeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const ProducerConsumerNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const ProvideNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const RealizeNode* op, Args...
args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const PrefetchNode* op, Args... args) STMT_FUNCTOR_DEFAULT; @@ -117,7 +116,6 @@ class StmtFunctor<R(const Stmt& n, Args... args)> { IR_STMT_FUNCTOR_DISPATCH(StoreNode); IR_STMT_FUNCTOR_DISPATCH(FreeNode); IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode); - IR_STMT_FUNCTOR_DISPATCH(ProducerConsumerNode); IR_STMT_FUNCTOR_DISPATCH(ProvideNode); IR_STMT_FUNCTOR_DISPATCH(RealizeNode); IR_STMT_FUNCTOR_DISPATCH(PrefetchNode); @@ -158,7 +156,6 @@ class TVM_DLL StmtVisitor : void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const FreeNode* op) override; void VisitStmt_(const AssertStmtNode* op) override; - void VisitStmt_(const ProducerConsumerNode* op) override; void VisitStmt_(const ProvideNode* op) override; void VisitStmt_(const RealizeNode* op) override; void VisitStmt_(const PrefetchNode* op) override; @@ -253,7 +250,6 @@ class TVM_DLL StmtMutator : Stmt VisitStmt_(const BufferStoreNode* op) override; Stmt VisitStmt_(const FreeNode* op) override; Stmt VisitStmt_(const AssertStmtNode* op) override; - Stmt VisitStmt_(const ProducerConsumerNode* op) override; Stmt VisitStmt_(const ProvideNode* op) override; Stmt VisitStmt_(const RealizeNode* op) override; Stmt VisitStmt_(const PrefetchNode* op) override; diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index a50c10dee357..d2238ad754ac 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -27,7 +27,7 @@ from .expr import Select, BufferLoad, Load, Ramp, Broadcast, Shuffle, Call, Let from .expr import IterVar, Any -from .stmt import Stmt, LetStmt, AssertStmt, ProducerConsumer, For +from .stmt import Stmt, LetStmt, AssertStmt, For from .stmt import BufferStore, Store, Provide, Allocate, AttrStmt, Free, Realize, SeqStmt from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 4531cdfc35ac..c5b2a7957319 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -76,26 +76,6 @@ def __init__(self, condition, message, body): _ffi_api.AssertStmt, condition, message, body) -@tvm._ffi.register_object -class ProducerConsumer(Stmt): - """ProducerConsumer node. - - Parameters - ---------- - func : Operation - The Operation. - - is_producer : bool - Whether if the node is producer. - - body : Stmt - The body statement. - """ - def __init__(self, func, is_producer, body): - self.__init_handle_by_constructor__( - _ffi_api.ProducerConsumer, func, is_producer, body) - - @tvm._ffi.register_object class For(Stmt): """For node. 
@@ -425,6 +405,4 @@ def stmt_list(stmt): for x in stmt: res += stmt_list(x) return res - if isinstance(stmt, ProducerConsumer): - return stmt_list(stmt.body) return [stmt] diff --git a/src/contrib/hybrid/codegen_hybrid.cc b/src/contrib/hybrid/codegen_hybrid.cc index a17ae8786877..bb97900833dd 100644 --- a/src/contrib/hybrid/codegen_hybrid.cc +++ b/src/contrib/hybrid/codegen_hybrid.cc @@ -399,10 +399,6 @@ void CodeGenHybrid::VisitStmt_(const EvaluateNode* op) { stream << str << "\n"; } -void CodeGenHybrid::VisitStmt_(const ProducerConsumerNode* op) { - PrintStmt(op->body); -} - void CodeGenHybrid::PrintIndent() { stream << std::string(indent_, ' '); } diff --git a/src/contrib/hybrid/codegen_hybrid.h b/src/contrib/hybrid/codegen_hybrid.h index 9784defcba88..d282edbb1926 100644 --- a/src/contrib/hybrid/codegen_hybrid.h +++ b/src/contrib/hybrid/codegen_hybrid.h @@ -131,7 +131,6 @@ class CodeGenHybrid : void VisitStmt_(const AssertStmtNode* op) override; void VisitStmt_(const EvaluateNode* op) override; void VisitStmt_(const SeqStmtNode* op) override; - void VisitStmt_(const ProducerConsumerNode* op) override; /*! * \brief Print Type represetnation of type t. * \param t The type representation. diff --git a/src/target/llvm/codegen_amdgpu.cc b/src/target/llvm/codegen_amdgpu.cc index 3d1654c24e4f..61121f67d111 100644 --- a/src/target/llvm/codegen_amdgpu.cc +++ b/src/target/llvm/codegen_amdgpu.cc @@ -71,55 +71,53 @@ class CodeGenAMDGPU : public CodeGenLLVM { void VisitStmt_(const AllocateNode* op) final { CHECK(!is_zero(op->condition)); llvm::Value* buf = nullptr; - if (op->new_expr.defined()) { - CHECK_EQ(op->free_function, "nop"); - buf = MakeValue(op->new_expr); - } else { - int32_t constant_size = op->constant_allocation_size(); - CHECK_GT(constant_size, 0) - << "Can only handle constant size stack allocation in GPU"; - StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; - if (constant_size % 4 == 0 && info.alignment == 0) { - info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); - } - // maximum necessary alignment in the AMD devices - if (info.alignment > 16) { - info.alignment = 16; - } - if (info.scope.rank == runtime::StorageRank::kLocal) { - // const int local_address_space = 5; - // TODO(tqchen): for higher version of LLVM, local address space can be set. - llvm::AllocaInst* alloca = WithFunctionEntry([&]() { - return builder_->CreateAlloca( - DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); - }); - if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) { + + int32_t constant_size = op->constant_allocation_size(); + CHECK_GT(constant_size, 0) + << "Can only handle constant size stack allocation in GPU"; + + StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; + if (constant_size % 4 == 0 && info.alignment == 0) { + info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); + } + // maximum necessary alignment in the AMD devices + if (info.alignment > 16) { + info.alignment = 16; + } + if (info.scope.rank == runtime::StorageRank::kLocal) { + // const int local_address_space = 5; + // TODO(tqchen): for higher version of LLVM, local address space can be set. 
+ llvm::AllocaInst* alloca = WithFunctionEntry([&]() { + return builder_->CreateAlloca( + DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); + }); + if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) { #if TVM_LLVM_VERSION >= 100 - alloca->setAlignment(llvm::Align(info.alignment)); + alloca->setAlignment(llvm::Align(info.alignment)); #else - alloca->setAlignment(info.alignment); + alloca->setAlignment(info.alignment); #endif - } - buf = alloca; - } else { - CHECK(info.scope.rank == runtime::StorageRank::kShared) - << "Can only allocate shared or local memory inside kernel"; - // Shared memory: address space == 3 - const unsigned shared_address_space = 3; - llvm::Type* type = llvm::ArrayType::get( - DTypeToLLVMType(op->dtype), constant_size); - // Allocate shared memory in global, address_space = 3 - llvm::GlobalVariable *global = new llvm::GlobalVariable( - *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared", - nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space); + } + buf = alloca; + } else { + CHECK(info.scope.rank == runtime::StorageRank::kShared) + << "Can only allocate shared or local memory inside kernel"; + // Shared memory: address space == 3 + const unsigned shared_address_space = 3; + llvm::Type* type = llvm::ArrayType::get( + DTypeToLLVMType(op->dtype), constant_size); + // Allocate shared memory in global, address_space = 3 + llvm::GlobalVariable *global = new llvm::GlobalVariable( + *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared", + nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space); #if TVM_LLVM_VERSION >= 100 - global->setAlignment(llvm::Align(info.alignment)); + global->setAlignment(llvm::Align(info.alignment)); #else - global->setAlignment(info.alignment); + global->setAlignment(info.alignment); #endif - buf = global; - } + buf = global; } + buf = builder_->CreatePointerCast( buf, DTypeToLLVMType(op->dtype)->getPointerTo( buf->getType()->getPointerAddressSpace())); diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc index 604533933b92..14302efe82fc 100644 --- a/src/target/llvm/codegen_llvm.cc +++ b/src/target/llvm/codegen_llvm.cc @@ -1268,10 +1268,7 @@ void CodeGenLLVM::VisitStmt_(const IfThenElseNode* op) { void CodeGenLLVM::VisitStmt_(const AllocateNode* op) { CHECK(!is_zero(op->condition)); llvm::Value* buf = nullptr; - if (op->new_expr.defined()) { - CHECK_EQ(op->free_function, "nop"); - buf = MakeValue(op->new_expr); - } else { + int32_t constant_size = op->constant_allocation_size(); CHECK_GT(constant_size, 0) << "Can only handle constant size stack allocation"; @@ -1296,7 +1293,7 @@ void CodeGenLLVM::VisitStmt_(const AllocateNode* op) { } info.alignment = alloca->getAlignment(); buf = alloca; - } + buf = builder_->CreatePointerCast( buf, DTypeToLLVMType(op->dtype)->getPointerTo( buf->getType()->getPointerAddressSpace())); @@ -1359,9 +1356,6 @@ void CodeGenLLVM::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); } -void CodeGenLLVM::VisitStmt_(const ProducerConsumerNode* op) { - this->VisitStmt(op->body); -} } // namespace codegen } // namespace tvm #endif // TVM_LLVM_VERSION diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h index e785f3eab275..5c7ca6fb622f 100644 --- a/src/target/llvm/codegen_llvm.h +++ b/src/target/llvm/codegen_llvm.h @@ -150,7 +150,6 @@ class CodeGenLLVM : void VisitStmt_(const LetStmtNode* op) override; void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const EvaluateNode* op) override; - 
void VisitStmt_(const ProducerConsumerNode* op) override; protected: /*! \brief The storage information */ diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc index 48c7968fb12c..40dc653f742b 100644 --- a/src/target/llvm/codegen_nvptx.cc +++ b/src/target/llvm/codegen_nvptx.cc @@ -48,55 +48,53 @@ class CodeGenNVPTX : public CodeGenLLVM { void VisitStmt_(const AllocateNode* op) final { CHECK(!is_zero(op->condition)); llvm::Value* buf = nullptr; - if (op->new_expr.defined()) { - CHECK_EQ(op->free_function, "nop"); - buf = MakeValue(op->new_expr); - } else { - int32_t constant_size = op->constant_allocation_size(); - CHECK_GT(constant_size, 0) - << "Can only handle constant size stack allocation in GPU"; - StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; - if (constant_size % 4 == 0 && info.alignment == 0) { - info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); - } - // maximum necessary alignment in the NV devices - if (info.alignment > 16) { - info.alignment = 16; - } - if (info.scope.rank == runtime::StorageRank::kLocal) { - // const int local_address_space = 5; - // TODO(tqchen): for higher version of LLVM, local address space can be set. - llvm::AllocaInst* alloca = WithFunctionEntry([&]() { - return builder_->CreateAlloca( - DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); - }); - if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) { + + int32_t constant_size = op->constant_allocation_size(); + CHECK_GT(constant_size, 0) + << "Can only handle constant size stack allocation in GPU"; + StorageInfo& info = alloc_storage_info_[op->buffer_var.get()]; + if (constant_size % 4 == 0 && info.alignment == 0) { + info.alignment = GetTempAllocaAlignment(op->dtype, constant_size); + } + // maximum necessary alignment in the NV devices + if (info.alignment > 16) { + info.alignment = 16; + } + + if (info.scope.rank == runtime::StorageRank::kLocal) { + // const int local_address_space = 5; + // TODO(tqchen): for higher version of LLVM, local address space can be set. 
+ llvm::AllocaInst* alloca = WithFunctionEntry([&]() { + return builder_->CreateAlloca( + DTypeToLLVMType(op->dtype), ConstInt32(constant_size)); + }); + if (alloca->getAlignment() < static_cast<uint32_t>(info.alignment)) { #if TVM_LLVM_VERSION >= 100 - alloca->setAlignment(llvm::Align(info.alignment)); + alloca->setAlignment(llvm::Align(info.alignment)); #else - alloca->setAlignment(info.alignment); + alloca->setAlignment(info.alignment); #endif - } - buf = alloca; - } else { - CHECK(info.scope.rank == runtime::StorageRank::kShared) - << "Can only allocate shared or local memory inside kernel"; - // Shared memory: address space == 3 - const unsigned shared_address_space = 3; - llvm::Type* type = llvm::ArrayType::get( - DTypeToLLVMType(op->dtype), constant_size); - // Allocate shared memory in global, address_space = 3 - llvm::GlobalVariable *global = new llvm::GlobalVariable( - *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared", - nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space); + } + buf = alloca; + } else { + CHECK(info.scope.rank == runtime::StorageRank::kShared) + << "Can only allocate shared or local memory inside kernel"; + // Shared memory: address space == 3 + const unsigned shared_address_space = 3; + llvm::Type* type = llvm::ArrayType::get( + DTypeToLLVMType(op->dtype), constant_size); + // Allocate shared memory in global, address_space = 3 + llvm::GlobalVariable *global = new llvm::GlobalVariable( + *module_, type, false, llvm::GlobalValue::PrivateLinkage, 0, ".shared", + nullptr, llvm::GlobalValue::NotThreadLocal, shared_address_space); #if TVM_LLVM_VERSION >= 100 - global->setAlignment(llvm::Align(info.alignment)); + global->setAlignment(llvm::Align(info.alignment)); #else - global->setAlignment(info.alignment); + global->setAlignment(info.alignment); #endif - buf = global; - } + buf = global; } + buf = builder_->CreatePointerCast( buf, DTypeToLLVMType(op->dtype)->getPointerTo( buf->getType()->getPointerAddressSpace())); diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc index ac13f8a50091..6e7784c81f85 100644 --- a/src/target/source/codegen_c.cc +++ b/src/target/source/codegen_c.cc @@ -814,14 +814,7 @@ void CodeGenC::VisitStmt_(const LetStmtNode* op) { void CodeGenC::VisitStmt_(const AllocateNode* op) { CHECK(!is_zero(op->condition)); std::string vid = AllocVarID(op->buffer_var.get()); - if (op->new_expr.defined()) { - // Prefer global static allocation for the program - CHECK_EQ(op->free_function, "nop"); - std::string new_data = PrintExpr(op->new_expr); - this->PrintIndent(); - PrintType(op->dtype, stream); - stream << "* "<< vid << '=' << new_data << ";\n"; - } else { + this->PrintIndent(); int32_t constant_size = op->constant_allocation_size(); CHECK_GT(constant_size, 0) @@ -833,7 +826,7 @@ void CodeGenC::VisitStmt_(const AllocateNode* op) { PrintType(op->dtype, stream); stream << ' '<< vid << '[' << constant_size << "];\n"; - } + RegisterHandleType(op->buffer_var.get(), op->dtype); this->PrintStmt(op->body); } @@ -942,10 +935,6 @@ void CodeGenC::VisitStmt_(const EvaluateNode* op) { } } -void CodeGenC::VisitStmt_(const ProducerConsumerNode* op) { - PrintStmt(op->body); -} - void CodeGenC::PrintVecElemLoadExpr( DataType t, int i, const std::string& value, std::ostream& os) { CHECK_GT(t.lanes(), 1); diff --git a/src/target/source/codegen_c.h b/src/target/source/codegen_c.h index 49139de2fd1c..db655beded02 100644 --- a/src/target/source/codegen_c.h +++ b/src/target/source/codegen_c.h @@ -153,7 +153,6 @@ class CodeGenC : 
void VisitStmt_(const AssertStmtNode* op) override; void VisitStmt_(const EvaluateNode* op) override; void VisitStmt_(const SeqStmtNode* op) override; - void VisitStmt_(const ProducerConsumerNode* op) override; /*! * Print Type represetnation of type t. * \param t The type representation. diff --git a/src/target/source/codegen_cuda.cc b/src/target/source/codegen_cuda.cc index c7971cef1bf6..02b5b413562e 100644 --- a/src/target/source/codegen_cuda.cc +++ b/src/target/source/codegen_cuda.cc @@ -514,51 +514,44 @@ void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) { void CodeGenCUDA::VisitStmt_(const AllocateNode* op) { CHECK(!is_zero(op->condition)); std::string vid = AllocVarID(op->buffer_var.get()); - if (op->new_expr.defined()) { - // Prefer global static allocation for the program - CHECK_EQ(op->free_function, "nop"); - std::string new_data = PrintExpr(op->new_expr); - this->PrintIndent(); - PrintType(op->dtype, stream); - stream << "* "<< vid << '=' << new_data << ";\n"; - } else { - this->PrintIndent(); - int32_t constant_size = op->constant_allocation_size(); - CHECK_GT(constant_size, 0) - << "Can only handle constant size stack allocation for now"; - const VarNode* buffer = op->buffer_var.as<VarNode>(); - std::string scope = alloc_storage_scope_.at(buffer); - if (scope.find("wmma.") == 0) { - if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { - CHECK(op->dtype == DataType::Float(16) || - op->dtype == DataType::Int(8) || - op->dtype == DataType::UInt(8) || - op->dtype == DataType::Int(4) || - op->dtype == DataType::UInt(4) || - op->dtype == DataType::Int(1)) - << "Matrix_a and matrix_b only support half or char or unsigned char " - << "or uint4 or int4 or int1 type for now"; - } else { - CHECK(op->dtype == DataType::Float(16) || - op->dtype == DataType::Float(32) || - op->dtype == DataType::Int(32)) - << "Accumulator only support half, float and int type for now"; - } - constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); - PrintWmmaScope(scope, op->dtype, buffer, stream); + + this->PrintIndent(); + int32_t constant_size = op->constant_allocation_size(); + CHECK_GT(constant_size, 0) + << "Can only handle constant size stack allocation for now"; + const VarNode* buffer = op->buffer_var.as<VarNode>(); + std::string scope = alloc_storage_scope_.at(buffer); + if (scope.find("wmma.") == 0) { + if (scope == "wmma.matrix_a" || scope == "wmma.matrix_b") { + CHECK(op->dtype == DataType::Float(16) || + op->dtype == DataType::Int(8) || + op->dtype == DataType::UInt(8) || + op->dtype == DataType::Int(4) || + op->dtype == DataType::UInt(4) || + op->dtype == DataType::Int(1)) + << "Matrix_a and matrix_b only support half or char or unsigned char " + << "or uint4 or int4 or int1 type for now"; } else { - PrintStorageScope(scope, stream); - stream << ' '; - PrintType(op->dtype, stream); - } - if ((op->dtype == DataType::Int(4) || - op->dtype == DataType::UInt(4) || - op->dtype == DataType::Int(1)) && scope == "shared") { - constant_size = constant_size / (32 / op->dtype.bits()); + CHECK(op->dtype == DataType::Float(16) || + op->dtype == DataType::Float(32) || + op->dtype == DataType::Int(32)) + << "Accumulator only support half, float and int type for now"; } - stream << ' '<< vid << '[' - << constant_size << "];\n"; + constant_size = GetWmmaFragmentSize(scope, buffer, constant_size); + PrintWmmaScope(scope, op->dtype, buffer, stream); + } else { + PrintStorageScope(scope, stream); + stream << ' '; + PrintType(op->dtype, stream); + } + if ((op->dtype == DataType::Int(4) || + op->dtype == 
DataType::UInt(4) || + op->dtype == DataType::Int(1)) && scope == "shared") { + constant_size = constant_size / (32 / op->dtype.bits()); } + stream << ' '<< vid << '[' + << constant_size << "];\n"; + RegisterHandleType(op->buffer_var.get(), op->dtype); this->PrintStmt(op->body); } diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index bfe21b024426..1d8004e9938f 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -586,7 +586,6 @@ void CodeGenSPIRV::VisitStmt_(const IfThenElseNode* op) { void CodeGenSPIRV::VisitStmt_(const AllocateNode* op) { CHECK(!is_zero(op->condition)); - CHECK(!op->new_expr.defined()); CHECK(!op->dtype.is_handle()); int32_t constant_size = op->constant_allocation_size(); CHECK_GT(constant_size, 0) @@ -659,9 +658,5 @@ void CodeGenSPIRV::VisitStmt_(const EvaluateNode* op) { MakeValue(op->value); } -void CodeGenSPIRV::VisitStmt_(const ProducerConsumerNode* op) { - this->VisitStmt(op->body); -} - } // namespace codegen } // namespace tvm diff --git a/src/target/spirv/codegen_spirv.h b/src/target/spirv/codegen_spirv.h index edcee20f173f..f50760711dec 100644 --- a/src/target/spirv/codegen_spirv.h +++ b/src/target/spirv/codegen_spirv.h @@ -99,7 +99,6 @@ class CodeGenSPIRV: void VisitStmt_(const LetStmtNode* op) override; void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const EvaluateNode* op) override; - void VisitStmt_(const ProducerConsumerNode* op) override; protected: /*! \brief The storage information */ diff --git a/src/target/stackvm/codegen_stackvm.cc b/src/target/stackvm/codegen_stackvm.cc index 383aaf38cea3..b28b6a1f5fb4 100644 --- a/src/target/stackvm/codegen_stackvm.cc +++ b/src/target/stackvm/codegen_stackvm.cc @@ -156,16 +156,7 @@ void CodeGenStackVM::VisitStmt_(const StoreNode* op) { } void CodeGenStackVM::VisitStmt_(const AllocateNode* op) { - CHECK(!is_zero(op->condition)); - int vid = AllocVarID(op->buffer_var.get()); - if (op->new_expr.defined()) { - // Prefer global static allocation for the program - CHECK_EQ(op->free_function, "nop"); - this->Push(op->new_expr); - this->PushOp(StackVM::STORE_HEAP, vid); - } else { - LOG(FATAL) << "Dynamic allocation not supported"; - } + LOG(FATAL) << "Dynamic allocation not supported"; } void CodeGenStackVM::VisitExpr_(const CallNode* op) { @@ -410,10 +401,6 @@ void CodeGenStackVM::VisitExpr_(const NotNode* op) { this->PushOp(StackVM::NOT); } -void CodeGenStackVM::VisitStmt_(const ProducerConsumerNode* op) { - this->Push(op->body); -} - void CodeGenStackVM::VisitStmt_(const ForNode* op) { CHECK(is_zero(op->min)); int vid = this->AllocVarID(op->loop_var.get()); diff --git a/src/target/stackvm/codegen_stackvm.h b/src/target/stackvm/codegen_stackvm.h index fd370d285ea8..31036822649d 100644 --- a/src/target/stackvm/codegen_stackvm.h +++ b/src/target/stackvm/codegen_stackvm.h @@ -147,7 +147,6 @@ class CodeGenStackVM void VisitStmt_(const AssertStmtNode* op) final; void VisitStmt_(const EvaluateNode* op) final; void VisitStmt_(const SeqStmtNode* op) final; - void VisitStmt_(const ProducerConsumerNode* op) final; private: bool debug_{false}; diff --git a/src/te/schedule/schedule_ops.cc b/src/te/schedule/schedule_ops.cc index bec677a03228..57b637df0570 100644 --- a/src/te/schedule/schedule_ops.cc +++ b/src/te/schedule/schedule_ops.cc @@ -43,9 +43,6 @@ Stmt MakePipeline(const Stage& s, Stmt consumer, bool debug_keep_trivial_loop) { Stmt producer = s->op->BuildProvide(s, dom_map, debug_keep_trivial_loop); - if (producer.defined()) { - producer 
= ProducerConsumerNode::make(s->op, true, producer); - } if (s->double_buffer) { producer = AttrStmtNode::make( s->op, tir::attr::double_buffer_scope, 1, producer); @@ -53,7 +50,6 @@ Stmt MakePipeline(const Stage& s, Stmt pipeline = producer; if (consumer.defined() && !is_no_op(consumer)) { - consumer = ProducerConsumerNode::make(s->op, false, consumer); pipeline = SeqStmt({producer, consumer}); } pipeline = s->op->BuildRealize(s, dom_map, pipeline); @@ -163,20 +159,6 @@ class InjectScanStep : public StmtMutator { // Replace the init and update's expression by scan's buffer. class SchedulePostProc : public StmtExprMutator { public: - Stmt VisitStmt_(const ProducerConsumerNode* op) final { - auto it = replace_op_.find(op->func.get()); - if (it != replace_op_.end()) { - Stmt body = this->VisitStmt(op->body); - if (it->second.defined()) { - return ProducerConsumerNode::make( - it->second, op->is_producer, body); - } else { - return body; - } - } else { - return StmtExprMutator::VisitStmt_(op); - } - } Stmt VisitStmt_(const LetStmtNode* op) final { if (!HasSideEffect(op->value)) { var_value_[op->var.get()] = this->VisitExpr(op->value); diff --git a/src/tir/analysis/verify_memory.cc b/src/tir/analysis/verify_memory.cc index 2a87b2e3e271..8e684e966770 100644 --- a/src/tir/analysis/verify_memory.cc +++ b/src/tir/analysis/verify_memory.cc @@ -39,8 +39,8 @@ namespace { * threads, CPU code is generated that tries to access GPU memory, * which is illegal. * - * This pass performs such verification by checking if all Producer/Consumer - * with memory accesses are bound with threads when device type is GPU. + * This pass performs such verification by checking if all + * memory accesses are bound with threads when device type is GPU. */ class MemoryAccessVerifier final : protected StmtExprVisitor { public: @@ -94,12 +94,6 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { } } - void VisitStmt_(const ProducerConsumerNode* op) final { - EnterProducerConsumer(op); - StmtExprVisitor::VisitStmt_(op); - ExitProducerConsumer(); - } - void VisitExpr_(const LoadNode* op) final { HandleLoadStoreToVariable(op->buffer_var); return StmtExprVisitor::VisitExpr_(op); @@ -138,11 +132,6 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { // We skip the access within thread env. if (InThreadEnv()) return; - // We only check access within a producer/consumer. - // Because for load/store out side of producer/consumer, - // they don't have to be in thread env to stay legal (e.g. Load of args). - if (!InProducerConsumer()) return; - // We only handle the variable from function argument. // If it does not come from args, then it could be allocated internally, // it may possibly be in host or device address space. @@ -158,10 +147,6 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { bool InThreadEnv() const { return in_thread_env_; } void EnterThreadEnv() { in_thread_env_ = true; } void ExitThreadEnv() { in_thread_env_ = false; } - bool InProducerConsumer() const { return pc_ != nullptr; } - const ProducerConsumerNode *GetCurrentProducerConsumer() const { return pc_; } - void EnterProducerConsumer(const ProducerConsumerNode *pc) { this->pc_ = pc; } - void ExitProducerConsumer() { pc_ = nullptr; } void SetFailure() { failure_ = true; } //@} @@ -180,7 +165,6 @@ class MemoryAccessVerifier final : protected StmtExprVisitor { /// Status of visitor //@{ bool in_thread_env_{false}; - const ProducerConsumerNode *pc_{nullptr}; bool failure_{false}; ///< If the verification fails (i.e. 
has illegal access) //@} tir::PrimFunc func_{nullptr}; ///< Function to be verified. @@ -197,13 +181,18 @@ void VerifyMemory(const IRModule& mod) { auto target = func->GetAttr<Target>(tvm::attr::kTarget); CHECK(target.defined()) << "LowerWarpMemory: Require the target attribute"; - MemoryAccessVerifier v(func, target.value()->device_type); - v.Run(); - if (v.Failed()) { - LOG(FATAL) + + if (func->GetAttr<Integer>( + tvm::attr::kCallingConv, + Integer(CallingConv::kDefault)) == CallingConv::kDefault) { + MemoryAccessVerifier v(func, target.value()->device_type); + v.Run(); + if (v.Failed()) { + LOG(FATAL) << "ValueError: Direct host side access to device memory is detected." << " Did you forget to bind?\n" << func; + } } } } diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 705fe7bdf26e..1f6a7dd027ea 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -82,20 +82,6 @@ TVM_REGISTER_GLOBAL("tir.AssertStmt") } }); -Stmt ProducerConsumerNode::make(FunctionRef func, bool is_producer, Stmt body) { - CHECK(body.defined()); - - ObjectPtr<ProducerConsumerNode> node = make_object<ProducerConsumerNode>(); - node->func = std::move(func); - node->is_producer = is_producer; - node->body = std::move(body); - return Stmt(node); -} - -TVM_REGISTER_GLOBAL("tir.ProducerConsumer") -.set_body_typed(ProducerConsumerNode::make); - - Stmt ForNode::make(Var loop_var, PrimExpr min, PrimExpr extent, @@ -184,9 +170,7 @@ Stmt AllocateNode::make(Var buffer_var, DataType dtype, Array<PrimExpr> extents, PrimExpr condition, - Stmt body, - PrimExpr new_expr, - std::string free_function) { + Stmt body) { for (size_t i = 0; i < extents.size(); ++i) { CHECK(extents[i].defined()); CHECK(extents[i].dtype().is_scalar()); @@ -201,8 +185,6 @@ Stmt AllocateNode::make(Var buffer_var, node->extents = std::move(extents); node->condition = std::move(condition); node->body = std::move(body); - node->new_expr = std::move(new_expr); - node->free_function = std::move(free_function); return Stmt(node); } @@ -381,22 +363,6 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->Print(op->body); }); -TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) -.set_dispatch<ProducerConsumerNode>([](const ObjectRef& node, ReprPrinter* p) { - auto* op = static_cast<const ProducerConsumerNode*>(node.get()); - if (op->is_producer) { - p->PrintIndent(); - p->stream << "produce " << op->func->func_name() << " {\n"; - p->indent += 2; - p->Print(op->body); - p->indent -= 2; - p->PrintIndent(); - p->stream << "}\n"; - } else { - p->Print(op->body); - } - }); - std::ostream &operator<<(std::ostream& out, ForType type) { // NOLINT(*) switch (type) { case ForType::Serial: @@ -615,7 +581,6 @@ TVM_REGISTER_NODE_TYPE(CallNode); TVM_REGISTER_NODE_TYPE(LetNode); TVM_REGISTER_NODE_TYPE(LetStmtNode); TVM_REGISTER_NODE_TYPE(AssertStmtNode); -TVM_REGISTER_NODE_TYPE(ProducerConsumerNode); TVM_REGISTER_NODE_TYPE(ForNode); TVM_REGISTER_NODE_TYPE(StoreNode); TVM_REGISTER_NODE_TYPE(ProvideNode); diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index 96fc4354aa94..ed3c2c75ef47 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -149,9 +149,6 @@ void StmtVisitor::VisitStmt_(const AllocateNode* op) { VisitArray(op->extents, [this](const PrimExpr& e) { this->VisitExpr(e); }); this->VisitStmt(op->body); this->VisitExpr(op->condition); - if (op->new_expr.defined()) { - this->VisitExpr(op->new_expr); - } } void StmtVisitor::VisitStmt_(const StoreNode* op) { @@ -180,10 +177,6 @@ void StmtVisitor::VisitStmt_(const AssertStmtNode* op) { this->VisitStmt(op->body); } -void StmtVisitor::VisitStmt_(const ProducerConsumerNode* op) { - this->VisitStmt(op->body); -} - 
void StmtVisitor::VisitStmt_(const ProvideNode* op) { VisitArray(op->args, [this](const PrimExpr& e) { this->VisitExpr(e); }); this->VisitExpr(op->value); @@ -291,21 +284,16 @@ Stmt StmtMutator::VisitStmt_(const AllocateNode* op) { Array<PrimExpr> extents = Internal::Mutate(this, op->extents); Stmt body = this->VisitStmt(op->body); PrimExpr condition = this->VisitExpr(op->condition); - PrimExpr new_expr; - if (op->new_expr.defined()) { - new_expr = this->VisitExpr(op->new_expr); - } + if (extents.same_as(op->extents) && body.same_as(op->body) && - condition.same_as(op->condition) && - new_expr.same_as(op->new_expr)) { + condition.same_as(op->condition)) { return GetRef<Stmt>(op); } else { auto n = CopyOnWrite(op); n->extents = std::move(extents); n->body = std::move(body); n->condition = std::move(condition); - n->new_expr = std::move(new_expr); return Stmt(n); } } @@ -475,17 +463,6 @@ Stmt StmtMutator::VisitStmt_(const AssertStmtNode* op) { } } -Stmt StmtMutator::VisitStmt_(const ProducerConsumerNode* op) { - Stmt body = this->VisitStmt(op->body); - if (body.same_as(op->body)) { - return GetRef<Stmt>(op); - } else { - auto n = CopyOnWrite(op); - n->body = std::move(body); - return Stmt(n); - } -} - Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) { PrimExpr value = this->VisitExpr(op->value); if (value.same_as(op->value)) { diff --git a/src/tir/pass/inject_virtual_thread.cc b/src/tir/pass/inject_virtual_thread.cc index 99e11491c4b1..e9c403ca5cb5 100644 --- a/src/tir/pass/inject_virtual_thread.cc +++ b/src/tir/pass/inject_virtual_thread.cc @@ -129,9 +129,6 @@ class VarTouchedAnalysis : public StmtVisitor { tc(op->extents[i]); } tc.VisitExpr(op->condition); - if (op->new_expr.defined()) { - tc(op->new_expr); - } Record(op->buffer_var.get(), tc); this->VisitStmt(op->body); } @@ -371,9 +368,6 @@ class VTInjector : public StmtExprMutator { } // Allocate Stmt VisitStmt_(const AllocateNode* op) final { - if (op->new_expr.defined() && !vt_loop_injected_) { - return InjectVTLoop(GetRef<Stmt>(op), true); - } PrimExpr condition = this->VisitExpr(op->condition); if (visit_touched_var_ && !vt_loop_injected_) { return InjectVTLoop(GetRef<Stmt>(op), true); @@ -419,8 +413,7 @@ class VTInjector : public StmtExprMutator { } else { return AllocateNode::make( op->buffer_var, op->dtype, - extents, condition, body, - op->new_expr, op->free_function); + extents, condition, body); } } diff --git a/src/tir/pass/lift_attr_scope.cc b/src/tir/pass/lift_attr_scope.cc index 2874ac2b19de..9aa037feb460 100644 --- a/src/tir/pass/lift_attr_scope.cc +++ b/src/tir/pass/lift_attr_scope.cc @@ -58,8 +58,7 @@ class AttrScopeLifter : public StmtMutator { attr_value_ = PrimExpr(); return AllocateNode::make( op->buffer_var, op->dtype, - op->extents, op->condition, body, - op->new_expr, op->free_function); + op->extents, op->condition, body); } else { return stmt; } diff --git a/src/tir/pass/remove_no_op.cc b/src/tir/pass/remove_no_op.cc index 3b9f823517ac..181a8c483e4e 100644 --- a/src/tir/pass/remove_no_op.cc +++ b/src/tir/pass/remove_no_op.cc @@ -79,11 +79,7 @@ class NoOpRemover : public StmtMutator { op = stmt.as<AllocateNode>(); return is_no_op(op->body) ? MakeEvaluate(op->extents) : stmt; } - Stmt VisitStmt_(const ProducerConsumerNode* op) final { - Stmt stmt = StmtMutator::VisitStmt_(op); - op = stmt.as<ProducerConsumerNode>(); - return is_no_op(op->body) ? 
op->body : stmt; - } + Stmt VisitStmt_(const RealizeNode* op) final { Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as<RealizeNode>(); diff --git a/src/tir/pass/ssa.cc b/src/tir/pass/ssa.cc index 833702e8202e..daef32c01bdb 100644 --- a/src/tir/pass/ssa.cc +++ b/src/tir/pass/ssa.cc @@ -158,7 +158,7 @@ class IRConvertSSA final : public StmtExprMutator { op = stmt.as<AllocateNode>(); return AllocateNode::make( new_var, op->dtype, op->extents, op->condition, - op->body, op->new_expr, op->free_function); + op->body); } else { defined_.insert(v.get()); return StmtExprMutator::VisitStmt_(op); diff --git a/src/tir/pass/vectorize_loop.cc b/src/tir/pass/vectorize_loop.cc index d62bd1f2584e..b73587db2ab6 100644 --- a/src/tir/pass/vectorize_loop.cc +++ b/src/tir/pass/vectorize_loop.cc @@ -403,10 +403,6 @@ class Vectorizer : public StmtExprMutator { } // Allocate Stmt VisitStmt_(const AllocateNode* op) final { - if (op->new_expr.defined()) { - LOG(WARNING) << "Cannot vectorize with new expr"; - return Scalarize(GetRef<Stmt>(op)); - } PrimExpr condition = this->VisitExpr(op->condition); if (condition.dtype().is_vector()) { LOG(WARNING) << "Cannot handle vector extent in alloc "; @@ -429,8 +425,7 @@ class Vectorizer : public StmtExprMutator { body = this->VisitStmt(body); return AllocateNode::make( op->buffer_var, op->dtype, - extents, condition, body, - op->new_expr, op->free_function); + extents, condition, body); } // scalarize the statment Stmt Scalarize(Stmt stmt) { diff --git a/src/tir/pass/verify_gpu_code.cc b/src/tir/pass/verify_gpu_code.cc index f05423b7dca5..70d909a859cc 100644 --- a/src/tir/pass/verify_gpu_code.cc +++ b/src/tir/pass/verify_gpu_code.cc @@ -56,29 +56,6 @@ class GPUCodeVerifier : public StmtVisitor { return valid_; } - void VisitStmt_(const ProducerConsumerNode* op) final { - if (nest_level_ == 0) { - // enter a new kernel, reset statistics - Reset_(); - } - - if (op->is_producer) { - nest_level_++; - StmtVisitor::VisitStmt_(op); - nest_level_--; - } else { - StmtVisitor::VisitStmt_(op); - } - - if (nest_level_ == 0) { - // exit a kernel, check the validity - valid_ &= thread_per_block_ <= max_threads_per_block_; - - valid_ &= local_memory_per_block_ <= max_local_memory_per_block_; - valid_ &= shared_memory_per_block_ <= max_shared_memory_per_block_; - } - } - void VisitStmt_(const AllocateNode* op) final { StmtVisitor::VisitStmt_(op); // visit an allocation of a buffer in shared memory, record its size @@ -99,7 +76,13 @@ class GPUCodeVerifier : public StmtVisitor { } else if (op_value == "shared") { visited_shared_buffers_.insert(op->node.as<VarNode>()); } + StmtVisitor::VisitStmt_(op); } else if (op->attr_key == attr::thread_extent) { + if (nest_level_ == 0) { + // enter a new kernel, reset statistics + Reset_(); + } + Var var = op->node.as<IterVarNode>()->var; const auto *extent = op->value.as<IntImmNode>(); CHECK(extent); @@ -133,8 +116,21 @@ class GPUCodeVerifier : public StmtVisitor { } } } + + nest_level_++; + StmtVisitor::VisitStmt_(op); + nest_level_--; + + if (nest_level_ == 0) { + // exit a kernel, check the validity + valid_ &= thread_per_block_ <= max_threads_per_block_; + + valid_ &= local_memory_per_block_ <= max_local_memory_per_block_; + valid_ &= shared_memory_per_block_ <= max_shared_memory_per_block_; + } + } else { + StmtVisitor::VisitStmt_(op); } - StmtVisitor::VisitStmt_(op); } private: diff --git a/src/tir/transforms/lower_custom_datatypes.cc b/src/tir/transforms/lower_custom_datatypes.cc index 6cf9e3adce96..ce81528b8b35 100644 --- a/src/tir/transforms/lower_custom_datatypes.cc +++ 
b/src/tir/transforms/lower_custom_datatypes.cc @@ -79,9 +79,9 @@ class CustomDatatypesLowerer : public StmtExprMutator { if (toBeLowered) { auto new_allocate_type = DataType::UInt(allocate->dtype.bits(), allocate->dtype.lanes()); - return AllocateNode::make(allocate->buffer_var, new_allocate_type, allocate->extents, - allocate->condition, allocate->body, allocate->new_expr, - allocate->free_function); + return AllocateNode::make( + allocate->buffer_var, new_allocate_type, allocate->extents, + allocate->condition, allocate->body); } return stmt; } diff --git a/src/tir/transforms/lower_device_storage_access_info.cc b/src/tir/transforms/lower_device_storage_access_info.cc index 2438da958a70..e7f81ed929b9 100644 --- a/src/tir/transforms/lower_device_storage_access_info.cc +++ b/src/tir/transforms/lower_device_storage_access_info.cc @@ -51,12 +51,13 @@ class StorageAccessInfoLower : public StmtExprMutator { ++it->second.alloc_count; CHECK_LE(it->second.alloc_count, 1) << "Double allocation of " << it->second.scope.to_string(); + if (info->head_address.defined()) { - return AllocateNode::make( - op->buffer_var, op->dtype, op->extents, op->condition, - op->body, info->head_address, "nop"); + return LetStmtNode::make( + op->buffer_var, info->head_address, op->body); + } else { + return op->body; } - return op->body; } else { return stmt; } @@ -110,10 +111,10 @@ class StorageAccessInfoLower : public StmtExprMutator { } PrimExpr MakeTaggedAccessPtr(DataType ptr_type, - Var buffer_var, - DataType dtype, - PrimExpr offset, - const MemoryInfo& info) { + Var buffer_var, + DataType dtype, + PrimExpr offset, + const MemoryInfo& info) { if (ptr_type.is_handle()) { CHECK(info->head_address.defined()) << buffer_var << " is not adddressable."; diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index 58c966b21711..71ba468a950f 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -93,7 +93,6 @@ class BuiltinLower : public StmtExprMutator { // Lower allocate to device allocate when needed. Stmt stmt = StmtExprMutator::VisitStmt_(op); op = stmt.as<AllocateNode>(); - if (op->new_expr.defined()) return stmt; // Get constant allocation bound. 
int64_t dev_type; int64_t nbytes = GetVectorBytes(op->dtype); diff --git a/tests/python/unittest/test_te_build_lower.py b/tests/python/unittest/test_te_build_lower.py index 3ad1747d3ecc..442c4fed7b2f 100644 --- a/tests/python/unittest/test_te_build_lower.py +++ b/tests/python/unittest/test_te_build_lower.py @@ -49,8 +49,8 @@ def test_split_uneven_unique_likely(): sch = te.create_schedule(c.op) xo, xi = sch[c].split(x, 5) stmt = tvm.lower(sch, [a, b, c], simple_mode=True) - assert isinstance(stmt.body.body.body.body, tvm.tir.stmt.IfThenElse) - assert str(stmt.body.body.body.body).count("likely") == 1 + assert isinstance(stmt.body.body.body, tvm.tir.stmt.IfThenElse) + assert str(stmt.body.body.body).count("likely") == 1 if __name__ == "__main__": test_lower_rfactor() diff --git a/tests/python/unittest/test_te_hybrid_script.py b/tests/python/unittest/test_te_hybrid_script.py index f1e89672b818..b525d018340d 100644 --- a/tests/python/unittest/test_te_hybrid_script.py +++ b/tests/python/unittest/test_te_hybrid_script.py @@ -366,7 +366,7 @@ def foo(a): c = foo(a) s = te.create_schedule(c.op) ir = tvm.lower(s, [a, c], simple_mode=True) - assert not isinstance(ir, tvm.tir.AttrStmt) + func, ins, outs = run_and_check(foo, [a], target='cuda') run_and_check(func, ins, outs=outs, target='cuda') @@ -731,8 +731,6 @@ def outer_product(a, b): sch[c].vectorize(ji) sch[c].reorder(ii, io, joo, joi, ji) ir = tvm.lower(sch, [a, b, c], simple_mode=True) - assert isinstance(ir, tvm.tir.ProducerConsumer) - ir = ir.body assert isinstance(ir, tvm.tir.AttrStmt) ir = ir.body assert isinstance(ir, tvm.tir.For) @@ -754,8 +752,6 @@ def outer_product(a, b): sch = te.create_schedule(c.op) sch[c].fuse(c.op.axis[0], c.op.axis[1]) ir = tvm.lower(sch, [a, b, c], simple_mode=True) - assert isinstance(ir, tvm.tir.ProducerConsumer) - ir = ir.body assert isinstance(ir, tvm.tir.AttrStmt) ir = ir.body assert isinstance(ir, tvm.tir.For) diff --git a/tests/python/unittest/test_te_schedule.py b/tests/python/unittest/test_te_schedule.py index 5a4c02173f39..c9b422f7f0a4 100644 --- a/tests/python/unittest/test_te_schedule.py +++ b/tests/python/unittest/test_te_schedule.py @@ -284,10 +284,10 @@ def intrin_func(ins, outs, sp): C = te.compute((10,10), lambda i, j: intrin(i*i, A[i, j], i+j), name="C") s = te.create_schedule(C.op) stmt = tvm.lower(s, [A, C], simple_mode=True) - assert isinstance(stmt.body.body.body, tvm.tir.Evaluate) - assert len(stmt.body.body.body.value.args) == 5 - assert str(stmt.body.body.body.value.args[3]) == "(i*i)" - assert str(stmt.body.body.body.value.args[4]) == "(i + j)" + assert isinstance(stmt.body.body, tvm.tir.Evaluate) + assert len(stmt.body.body.value.args) == 5 + assert str(stmt.body.body.value.args[3]) == "(i*i)" + assert str(stmt.body.body.value.args[4]) == "(i + j)" if __name__ == "__main__": test_singleton() diff --git a/tests/python/unittest/test_te_schedule_ops.py b/tests/python/unittest/test_te_schedule_ops.py index 4e27ad3f2a58..3e521ab07023 100644 --- a/tests/python/unittest/test_te_schedule_ops.py +++ b/tests/python/unittest/test_te_schedule_ops.py @@ -72,7 +72,6 @@ def test_schedule_scan(): s = te.create_schedule(res.op) s = s.normalize() ir = tvm.lower(s, [s_state], simple_mode=True) - assert not hasattr(ir.body.body.body.body[1].body.body[1].body, "condition") bounds = tvm.te.schedule.InferBound(s) assert(bounds[res.op.scan_axis].min.value == 1) stmt = tvm.te.schedule.ScheduleOps(s, bounds) @@ -557,7 +556,6 @@ def collect_visit(stmt, f): return ret def visit_stmt(op): - print(op) if 
(isinstance(op, tvm.tir.Allocate)): return op.extents[0].value == 97 return False @@ -593,4 +591,4 @@ def visit_stmt(op): test_reduction_and_dummy_fuse_split() test_schedule_compute_inline() test_local_stage_predicate() - test_local_stage_predicate2() \ No newline at end of file + test_local_stage_predicate2() diff --git a/tests/python/unittest/test_te_schedule_tensorize.py b/tests/python/unittest/test_te_schedule_tensorize.py index 7dceaefd9761..dafffed9bd44 100644 --- a/tests/python/unittest/test_te_schedule_tensorize.py +++ b/tests/python/unittest/test_te_schedule_tensorize.py @@ -327,8 +327,8 @@ def intrin_func(ins, outs): stmt = tvm.te.schedule.ScheduleOps(s, dom_map) # The loop that we tried to tensorize still exists in the code # That means tensorize didn't work as expected - assert isinstance(stmt.body.body.body, tvm.tir.For) - assert stmt.body.body.body.loop_var.name == C.op.axis[0].var.name + assert isinstance(stmt.body.body, tvm.tir.For) + assert stmt.body.body.loop_var.name == C.op.axis[0].var.name diff --git a/tests/python/unittest/test_te_tensor.py b/tests/python/unittest/test_te_tensor.py index 762b3fe75180..55edd1c9958b 100644 --- a/tests/python/unittest/test_te_tensor.py +++ b/tests/python/unittest/test_te_tensor.py @@ -129,7 +129,7 @@ def intrin_func(ins, outs): s = te.create_schedule(C.op) stmt = tvm.lower(s, [A, B, C], simple_mode=True) - assert isinstance(stmt.body.body, tvm.tir.Evaluate) + assert isinstance(stmt.body, tvm.tir.Evaluate) def test_tensor_compute2(): M = 2048 @@ -172,8 +172,8 @@ def intrin_func(ins, outs): s = te.create_schedule(C.op) stmt = tvm.lower(s, [A, B, C], simple_mode=True) - assert isinstance(stmt.body.body.body[0], tvm.tir.Evaluate) - assert isinstance(stmt.body.body.body[1].body, tvm.tir.Evaluate) + assert isinstance(stmt.body.body[0], tvm.tir.Evaluate) + assert isinstance(stmt.body.body[1].body, tvm.tir.Evaluate) def test_tensor_scan(): m = te.size_var("m") diff --git a/tests/python/unittest/test_tir_constructor.py b/tests/python/unittest/test_tir_constructor.py index 9edaf92d0db7..7a03e48e2270 100644 --- a/tests/python/unittest/test_tir_constructor.py +++ b/tests/python/unittest/test_tir_constructor.py @@ -148,10 +148,6 @@ def test_stmt_constructor(): assert isinstance(x, tvm.tir.AssertStmt) assert x.body == nop - x = tvm.tir.ProducerConsumer(None, True, nop) - assert isinstance(x, tvm.tir.ProducerConsumer) - assert x.body == nop - x = tvm.tir.For(te.var("x"), 0, 10, 0, 0, nop) assert isinstance(x, tvm.tir.For) assert x.min.value == 0 diff --git a/tests/python/unittest/test_tir_pass_loop_partition.py b/tests/python/unittest/test_tir_pass_loop_partition.py index 0818c0ed0fe2..1256d8bbd4fc 100644 --- a/tests/python/unittest/test_tir_pass_loop_partition.py +++ b/tests/python/unittest/test_tir_pass_loop_partition.py @@ -23,14 +23,6 @@ def collect_visit(stmt, f): tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x : ret.append(f(x))) return ret -def find_top_produce(stmt): - def f(x, ret): - if isinstance(x, tvm.tir.ProducerConsumer): - ret.append(x) - ret = [] - tvm.tir.ir_pass.PostOrderVisit(stmt, lambda x : f(x, ret)) - return ret[-1] - def lower(sch, args): binds = {} arg_list = [] @@ -65,8 +57,8 @@ def test_basic(): stmt = tvm.te.schedule.ScheduleOps(s, bounds) stmt = tvm.tir.ir_pass.LoopPartition(stmt, False) stmt = tvm.tir.ir_pass.Simplify(stmt) - assert('if' not in str(stmt.body.body.body[0])) - assert('if' in str(stmt.body.body.body[1])) + assert('if' not in str(stmt.body.body[0])) + assert('if' in str(stmt.body.body[1])) def 
test_const_loop(): n = 21 @@ -81,7 +73,7 @@ def test_const_loop(): stmt = tvm.te.schedule.ScheduleOps(s, bounds) stmt = tvm.tir.ir_pass.LoopPartition(stmt, True) stmt = tvm.tir.ir_pass.Simplify(stmt) - assert('if' not in str(stmt.body.body.body[0])) + assert('if' not in str(stmt.body.body[0])) def test_multi_loop(): ib = tvm.tir.ir_builder.create() @@ -136,7 +128,7 @@ def test_thread_axis(): stmt = tvm.te.schedule.ScheduleOps(s, bounds) stmt = tvm.tir.ir_pass.LoopPartition(stmt, False) stmt = tvm.tir.ir_pass.Simplify(stmt) - assert('if' not in str(stmt.body.body.body[0])) + assert('if' not in str(stmt.body.body[0])) def test_vectorize(): n = te.size_var('n') @@ -156,7 +148,7 @@ def test_vectorize(): s[C].bind(tx, te.thread_axis("threadIdx.x")) s[C].vectorize(x) stmt = lower(s, [A, B]) - body = stmt.body.body.body.body.body + body = stmt.body.body.body.body assert(x.var.name not in str(body.condition)) assert(any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.tir.Ramp)))) @@ -199,7 +191,7 @@ def test_thread_axis2(): s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(tx, te.thread_axis("threadIdx.x")) stmt = lower(s, [A, B]) - for_body = stmt.body.body.body.body.body[0] + for_body = stmt.body.body.body.body[0] assert('threadIdx' not in str(for_body.extent)) def test_everything_during_deduction(): @@ -405,9 +397,7 @@ def test_double_splitting_with_indivisible_factors(): f = tvm.lower(s, [A, C, D], name="fadd1", simple_mode=False) func = tvm.build(f, target=target) - # Find the beginning of the Halide IR corresponding to kernel code - # and make sure it doesn't have an if statements left - top_produce = find_top_produce(f["fadd1"].body) + top_produce = f["fadd1"].body assert(not any(collect_visit(top_produce, lambda x: isinstance(x, tvm.tir.IfThenElse)))) # check functional correctness of generated code diff --git a/tests/python/unittest/test_tir_transform_narrow_datatype.py b/tests/python/unittest/test_tir_transform_narrow_datatype.py index a5583d5fe0a3..dbf22679c1a0 100644 --- a/tests/python/unittest/test_tir_transform_narrow_datatype.py +++ b/tests/python/unittest/test_tir_transform_narrow_datatype.py @@ -148,7 +148,7 @@ def check(m, target_bits, target_dtype): B = te.compute((), lambda *idx: te.sum(A[k], axis=k), name='B') s = te.create_schedule(B.op) stmt = lower_sch(s, [A, B], target_bits) - assert stmt.body[1].loop_var.dtype == target_dtype + assert stmt[1].loop_var.dtype == target_dtype # i32 -> i32 check(const(64, dtype='int32'), 32, 'int32') diff --git a/vta/python/vta/ir_pass.py b/vta/python/vta/ir_pass.py index 5924cdd6ca82..8a7798ab7e6f 100644 --- a/vta/python/vta/ir_pass.py +++ b/vta/python/vta/ir_pass.py @@ -60,11 +60,8 @@ def fold_uop_loop(stmt_in): def _fold_outermost_loop(body): stmt = body - while not isinstance(stmt, tvm.tir.For): - if isinstance(stmt, (tvm.tir.ProducerConsumer,)): - stmt = stmt.body - else: - return None, body, None + if not isinstance(stmt, tvm.tir.For): + return None, body, None loop_var = stmt.loop_var gemm_offsets = [None, None, None]
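
For downstream code that constructed the removed constructs, a minimal sketch of the post-change construction path, assuming the v0.7-era tvm.tir Python API that this diff leaves in place (the variable names below are illustrative only, not part of the patch):

    import tvm
    from tvm import tir

    # Allocate now takes exactly (buffer_var, dtype, extents, condition, body);
    # the deprecated new_expr/free_function arguments no longer exist.
    buf = tir.Var("buf", "handle")
    body = tir.Evaluate(tir.IntImm("int32", 0))
    alloc = tir.Allocate(buf, "float32", [1024],
                         tir.IntImm("uint1", 1), body)

    # Producer/consumer staging is now plain sequencing: SeqStmt (plus
    # AttrStmt annotations such as double_buffer_scope where needed)
    # replaces the removed ProducerConsumer wrapper, and SeqStmt itself
    # flattens nested sequences.
    producer = tir.Evaluate(tir.IntImm("int32", 1))
    consumer = tir.Evaluate(tir.IntImm("int32", 2))
    pipeline = tir.SeqStmt([producer, consumer])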