diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index b36f5cd7384d..febdac55d9aa 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -417,13 +417,6 @@ Var EnvThread(String thread_tag, DataType dtype = DataType::Int(32)); void BufferStore(Buffer buffer, PrimExpr value, Array indices, Optional predicate); -/*! - * \brief The prefetch hint for a buffer - * \param buffer The buffer to be prefetched. - * \param bounds The bounds to be prefetched. - */ -void Prefetch(Buffer buffer, Array bounds); - /*! * \brief Evaluate the input expression. * \param value The input expression to evaluate. diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 6d93a3a153ad..cb5db7e44f8a 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -29,7 +29,6 @@ #include #include #include -#include namespace tvm { namespace tir { @@ -335,124 +334,6 @@ class BufferRealize : public Stmt { TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRealizeNode); }; -/*! - * \brief Store value into mult-dimensional array that will be read by the consumer - * of the producer. - * - * \note This node only appears in high-level DSLs that are built on top of the TIR. - * It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower - * this node before TIR transformations. - * - * \sa DataProducer - */ -class ProducerStoreNode : public StmtNode { - public: - /*! \brief The producer to store the results into. */ - DataProducer producer; - /*! \brief The value to be stored. */ - PrimExpr value; - /*! \brief The index arguments of the function. */ - Array indices; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("producer", &producer); - v->Visit("value", &value); - v->Visit("indices", &indices); - v->Visit("span", &span); - } - - bool SEqualReduce(const ProducerStoreNode* other, SEqualReducer equal) const { - return equal(producer, other->producer) && equal(value, other->value) && - equal(indices, other->indices); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(producer); - hash_reduce(value); - hash_reduce(indices); - } - - static constexpr const char* _type_key = "tir.ProducerStore"; - TVM_DECLARE_FINAL_OBJECT_INFO(ProducerStoreNode, StmtNode); -}; - -/*! - * \brief Managed reference to ProducerStoreNode. - * \sa ProducerStoreNode - */ -class ProducerStore : public Stmt { - public: - TVM_DLL ProducerStore(DataProducer producer, PrimExpr value, Array indices, - Span span = Span()); - - TVM_DEFINE_OBJECT_REF_METHODS(ProducerStore, Stmt, ProducerStoreNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerStoreNode); -}; - -/*! - * \brief Annotate the bounds where the data produced by the producer - * need to be written and read in body. - * We will need to allocate space for the corresponding regions. - * - * \note This node only appears in high-level DSLs that are built on top of the TIR. - * It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower - * this node before TIR transformations. - * - * \sa DataProducer - */ -class ProducerRealizeNode : public StmtNode { - public: - /*! \brief The producer that produces the data. */ - DataProducer producer; - /*! \brief Bounds to be realized. */ - Region bounds; - /*! \brief Only realize if condition holds. */ - PrimExpr condition; - /*! \brief The body of realization. */ - Stmt body; - /*! \brief The storage scope associated with this realization. */ - String storage_scope; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("producer", &producer); - v->Visit("bounds", &bounds); - v->Visit("condition", &condition); - v->Visit("body", &body); - v->Visit("storage_scope", &storage_scope); - v->Visit("span", &span); - } - - bool SEqualReduce(const ProducerRealizeNode* other, SEqualReducer equal) const { - return equal(producer, other->producer) && equal(bounds, other->bounds) && - equal(condition, other->condition) && equal(body, other->body) && - equal(storage_scope, other->storage_scope); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(producer); - hash_reduce(bounds); - hash_reduce(condition); - hash_reduce(body); - hash_reduce(storage_scope); - } - - static constexpr const char* _type_key = "tir.ProducerRealize"; - TVM_DECLARE_FINAL_OBJECT_INFO(ProducerRealizeNode, StmtNode); -}; - -/*! - * \brief Managed reference to ProducerRealizeNode. - * \sa ProducerRealizeNode - */ -class ProducerRealize : public Stmt { - public: - TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body, - String storage_scope = "", Span span = Span()); - - TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerRealizeNode); -}; - /*! * \brief Allocate a buffer that can be used in body. */ @@ -1090,51 +971,6 @@ class While : public Stmt { TVM_DEFINE_OBJECT_REF_COW_METHOD(WhileNode); }; -/*! - * \brief A prefetch hint for a buffer - */ -class PrefetchNode : public StmtNode { - public: - /*! \brief The function to be prefetched. */ - Buffer buffer; - /*! \brief Bounds to be prefetched. */ - Array bounds; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("buffer", &buffer); - v->Visit("bounds", &bounds); - v->Visit("span", &span); - } - - bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const { - return equal(buffer, other->buffer) && equal(bounds, other->bounds); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(buffer); - hash_reduce(bounds); - } - - PrefetchNode() = default; - PrefetchNode(Buffer buffer, Array bounds, Span span = Span()) - : StmtNode(span), buffer(buffer), bounds(bounds) {} - - static constexpr const char* _type_key = "tir.Prefetch"; - TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode); -}; - -/*! - * \brief Managed reference to PrefetchNode. - * \sa PrefetchNode - */ -class Prefetch : public Stmt { - public: - TVM_DLL explicit Prefetch(Buffer buffer, Array bounds, Span span = Span()); - - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(PrefetchNode); -}; - /*! * \brief Representing the region of multi-dimensional buffer access. */ diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 141fe710b371..23747a7e936c 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -93,9 +93,6 @@ class StmtFunctor { virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const ProducerStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const ProducerRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const PrefetchNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const BlockNode* op, Args... args) STMT_FUNCTOR_DEFAULT; @@ -118,9 +115,6 @@ class StmtFunctor { IR_STMT_FUNCTOR_DISPATCH(AllocateConstNode); IR_STMT_FUNCTOR_DISPATCH(DeclBufferNode); IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode); - IR_STMT_FUNCTOR_DISPATCH(ProducerStoreNode); - IR_STMT_FUNCTOR_DISPATCH(ProducerRealizeNode); - IR_STMT_FUNCTOR_DISPATCH(PrefetchNode); IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode); IR_STMT_FUNCTOR_DISPATCH(EvaluateNode); IR_STMT_FUNCTOR_DISPATCH(BufferStoreNode); @@ -164,9 +158,6 @@ class TVM_DLL StmtVisitor : protected StmtFunctor { void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const BufferRealizeNode* op) override; void VisitStmt_(const AssertStmtNode* op) override; - void VisitStmt_(const ProducerStoreNode* op) override; - void VisitStmt_(const ProducerRealizeNode* op) override; - void VisitStmt_(const PrefetchNode* op) override; void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const EvaluateNode* op) override; void VisitStmt_(const BlockNode* op) override; @@ -265,9 +256,6 @@ class TVM_DLL StmtMutator : protected StmtFunctor { Stmt VisitStmt_(const BufferStoreNode* op) override; Stmt VisitStmt_(const BufferRealizeNode* op) override; Stmt VisitStmt_(const AssertStmtNode* op) override; - Stmt VisitStmt_(const ProducerStoreNode* op) override; - Stmt VisitStmt_(const ProducerRealizeNode* op) override; - Stmt VisitStmt_(const PrefetchNode* op) override; Stmt VisitStmt_(const SeqStmtNode* op) override; Stmt VisitStmt_(const EvaluateNode* op) override; Stmt VisitStmt_(const BlockNode* op) override; diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 1aaeaa034724..c7589f4a19a6 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1316,22 +1316,6 @@ def buffer_store( ) -def prefetch( - buffer: Buffer, # pylint: disable=redefined-outer-name - bounds: List[ir.Range], -) -> None: - """The prefetch hint for a buffer. - - Parameters - ---------- - buffer : Buffer - The buffer to be prefetched. - bounds : List[Range] - The range to be prefetched. - """ - return _ffi_api.Prefetch(buffer, bounds) # type: ignore[attr-defined] # pylint: disable=no-member - - def evaluate(value: PrimExpr) -> None: """Evaluate the input expression. @@ -2144,7 +2128,6 @@ def wrapped(*args, **kwargs): "launch_thread", "env_thread", "buffer_store", - "prefetch", "evaluate", "boolean", "handle", diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 24db80cb651a..5c4a2b91f5d7 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -32,15 +32,14 @@ from .stmt import ( BufferStore, BufferRealize, - ProducerStore, Allocate, AllocateConst, AttrStmt, DeclBuffer, ) -from .stmt import ProducerRealize, SeqStmt -from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list +from .stmt import SeqStmt +from .stmt import IfThenElse, Evaluate, stmt_seq, stmt_list from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize from .function import PrimFunc, TensorIntrin, IndexMap diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index a04f80b55e7a..ffb6fd6a7068 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -34,7 +34,7 @@ from tvm.runtime import Object, Scriptable, const, NDArray from . import _ffi_api -from .buffer import Buffer, DataProducer +from .buffer import Buffer from .expr import Var, IterVar @@ -293,42 +293,6 @@ def __init__( ) -@tvm.ffi.register_object("tir.ProducerStore") -class ProducerStore(Stmt): - """ProducerStore node. - - Parameters - ---------- - producer : DataProducer - The data producer. - - value : PrimExpr - The value to be stored. - - indices : list of Expr - The index arguments of the store. - - span : Optional[Span] - The location of the stmt in the source code. - """ - - producer: DataProducer - value: PrimExpr - indices: List[PrimExpr] - span: Optional[Span] - - def __init__( - self, - producer: DataProducer, - value: PrimExpr, - indices: List[PrimExpr], - span: Optional[Span] = None, - ) -> None: - self.__init_handle_by_constructor__( - _ffi_api.ProducerStore, producer, value, indices, span # type: ignore - ) - - @tvm.ffi.register_object("tir.Allocate") class Allocate(Stmt): """Allocate node. @@ -511,58 +475,6 @@ def __init__( ) -@tvm.ffi.register_object("tir.ProducerRealize") -class ProducerRealize(Stmt): - """ProducerRealize node. - - Parameters - ---------- - producer : DataProducer - The data producer. - - bounds : List[Range] - The bound of realize - - condition : PrimExpr - The realize condition. - - body : Stmt - The realize body - - storage_scope : str - The storage scope associated with this realization - - span : Optional[Span] - The location of the stmt in the source code. - """ - - producer: DataProducer - bounds: List[Range] - condition: PrimExpr - body: Stmt - storage_scope: str - span: Optional[Span] - - def __init__( - self, - producer: DataProducer, - bounds: List[Range], - condition: PrimExpr, - body: Stmt, - storage_scope: str = "", - span: Optional[Span] = None, - ) -> None: - self.__init_handle_by_constructor__( - _ffi_api.ProducerRealize, # type: ignore - producer, - bounds, - condition, - body, - storage_scope, - span, - ) - - @tvm.ffi.register_object("tir.SeqStmt") class SeqStmt(Stmt): """Sequence of statements. @@ -644,30 +556,6 @@ def __init__(self, value: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Evaluate, value, span) # type: ignore -@tvm.ffi.register_object("tir.Prefetch") -class Prefetch(Stmt): - """Prefetch node. - - Parameters - ---------- - buffer : Buffer - The buffer to be prefetched. - - bounds : List[Range] - The bounds to be prefetched. - - span : Optional[Span] - The location of the stmt in the source code. - """ - - buffer: Buffer - bounds: List[Range] - span: Optional[Span] - - def __init__(self, buffer: Buffer, bounds: List[Range], span: Optional[Span] = None) -> None: - self.__init_handle_by_constructor__(_ffi_api.Prefetch, buffer, bounds, span) # type: ignore - - @tvm.ffi.register_object("tir.BufferRegion") class BufferRegion(Object, Scriptable): """BufferRegion node. diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 2d61ca3e75f5..6f73254ff2ab 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -596,10 +596,6 @@ void BufferStore(Buffer buffer, PrimExpr value, Array indices, AddToParent(tvm::tir::BufferStore(buffer, value, indices, predicate)); } -void Prefetch(Buffer buffer, Array bounds) { - AddToParent(tvm::tir::Prefetch(buffer, bounds)); -} - DeclBufferFrame DeclBuffer(Array shape, DataType dtype, String buffer_name, Optional data, Optional> strides, Optional elem_offset, String storage_scope, int align, @@ -724,7 +720,6 @@ TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread") TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.EnvThread").set_body_typed(EnvThread); TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.BufferStore").set_body_typed(BufferStore); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Prefetch").set_body_typed(Prefetch); TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate); TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Ptr").set_body_typed(Ptr); diff --git a/src/script/printer/legacy_repr.cc b/src/script/printer/legacy_repr.cc index 5e414e90c262..57dd691b8897 100644 --- a/src/script/printer/legacy_repr.cc +++ b/src/script/printer/legacy_repr.cc @@ -625,21 +625,6 @@ TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) (*p) << "}\n"; }); -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - (*p) << op->producer->GetNameHint() << "["; - for (size_t i = 0; i < op->indices.size(); ++i) { - p->Print(op->indices[i]); - if (i < op->indices.size() - 1) (*p) << ", "; - } - (*p) << "]"; - (*p) << " ="; - p->Print(op->value); - (*p) << '\n'; - }); - TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { auto* op = static_cast(node.get()); @@ -682,50 +667,6 @@ TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) (*p) << op->body; }); -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - (*p) << "producer_realize " << op->producer->GetNameHint() << "("; - for (size_t i = 0; i < op->bounds.size(); ++i) { - (*p) << "["; - p->Print(op->bounds[i]->min); - (*p) << ", "; - p->Print(op->bounds[i]->extent); - (*p) << "]"; - if (i < op->bounds.size() - 1) (*p) << ", "; - } - (*p) << ")"; - if (!is_one(op->condition)) { - (*p) << " if "; - p->Print(op->condition); - } - (*p) << " {\n"; - - p->indent += 2; - p->Print(op->body); - p->indent -= 2; - - p->PrintIndent(); - (*p) << "}\n"; - }); - -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - (*p) << "prefetch " << op->buffer << "("; - for (size_t i = 0; i < op->bounds.size(); ++i) { - (*p) << "["; - p->Print(op->bounds[i]->min); - (*p) << ", "; - p->Print(op->bounds[i]->extent); - (*p) << "]"; - if (i < op->bounds.size() - 1) (*p) << ", "; - } - (*p) << ")"; - }); - TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { auto* op = static_cast(node.get()); diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc index 0427c359049b..f1b75066cdba 100644 --- a/src/script/printer/tir/buffer.cc +++ b/src/script/printer/tir/buffer.cc @@ -338,34 +338,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return prefix[BufferIndices(load->indices, p->Attr("indices"), d)]; }); -TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "", [](tir::ProducerStore store, ObjectPath p, IRDocsifier d) -> Doc { - ExprDoc prefix = IdDoc(store->producer->GetNameHint()); - prefix = prefix[BufferIndices(store->indices, p->Attr("indices"), d)]; - return AssignDoc(prefix, d->AsDoc(store->value, p->Attr("value")), std::nullopt); - }); - -TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "", [](tir::ProducerRealize stmt, ObjectPath p, IRDocsifier d) -> Doc { - ExprDoc prefix = IdDoc(stmt->producer->GetNameHint()); - prefix = prefix[BufferSlices(stmt->bounds, p->Attr("bounds"), d)]; - prefix = TIR(d, "ProducerRealize") - ->Call({prefix, d->AsDoc(stmt->condition, p->Attr("condition"))}); - With f(d, stmt); - AsDocBody(stmt->body, p->Attr("body"), f->get(), d); - return ScopeDoc(std::nullopt, prefix, (*f)->stmts); - }); - TVM_SCRIPT_REPR(tir::BufferRegionNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::BufferLoadNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::BufferStoreNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::BufferNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::MatchBufferRegionNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::ProducerLoadNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::ProducerStoreNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::ProducerRealizeNode, ReprPrintTIR); } // namespace printer } // namespace script diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc index 1d310c2a5a9f..239d9f67216f 100644 --- a/src/script/printer/tir/stmt.cc +++ b/src/script/printer/tir/stmt.cc @@ -199,16 +199,6 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return StmtBlockDoc((*f)->stmts); }); -TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "", [](tir::Prefetch stmt, ObjectPath p, IRDocsifier d) -> Doc { - return ExprStmtDoc(TIR(d, "prefetch") - ->Call({ - d->AsDoc(stmt->buffer, p->Attr("buffer")), - d->AsDoc(stmt->bounds, p->Attr("bounds")), - })); - }); - bool IsAllocateDeclBufferPattern(const tir::AllocateNode* allocate) { const tir::Var& buffer_var = allocate->buffer_var; if (const tir::DeclBufferNode* decl_buffer = allocate->body.as()) { @@ -462,7 +452,6 @@ TVM_SCRIPT_REPR(tir::WhileNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::AllocateNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::AllocateConstNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::DeclBufferNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::PrefetchNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::SeqStmtNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::IfThenElseNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::EvaluateNode, ReprPrintTIR); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 62baf45bc78e..f400ca0d507e 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -205,24 +205,6 @@ TVM_FFI_REGISTER_GLOBAL("tir.While").set_body_typed([](PrimExpr condition, Stmt TVM_REGISTER_NODE_TYPE(WhileNode); -// ProducerStore -ProducerStore::ProducerStore(DataProducer producer, PrimExpr value, Array indices, - Span span) { - ObjectPtr node = make_object(); - node->producer = std::move(producer); - node->value = std::move(value); - node->indices = std::move(indices); - node->span = std::move(span); - data_ = std::move(node); -} - -TVM_FFI_REGISTER_GLOBAL("tir.ProducerStore") - .set_body_typed([](DataProducer producer, PrimExpr value, Array indices, Span span) { - return ProducerStore(producer, value, indices, span); - }); - -TVM_REGISTER_NODE_TYPE(ProducerStoreNode); - // Allocate Allocate::Allocate(Var buffer_var, DataType dtype, Array extents, PrimExpr condition, Stmt body, Map annotations, Span span) { @@ -353,49 +335,6 @@ TVM_FFI_REGISTER_GLOBAL("tir.DeclBuffer").set_body_typed([](Buffer buffer, Stmt TVM_REGISTER_NODE_TYPE(DeclBufferNode); -// ProducerRealize -ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, - Stmt body, String storage_scope, Span span) { - for (size_t i = 0; i < bounds.size(); ++i) { - ICHECK(bounds[i]->min.defined()); - ICHECK(bounds[i]->extent.defined()); - ICHECK(bounds[i]->min.dtype().is_scalar()); - ICHECK(bounds[i]->extent.dtype().is_scalar()); - } - ICHECK(body.defined()); - ICHECK(condition.defined()); - ICHECK(condition.dtype().is_bool()); - - ObjectPtr node = make_object(); - node->producer = std::move(producer); - node->bounds = std::move(bounds); - node->condition = std::move(condition); - node->body = std::move(body); - node->span = std::move(span); - node->storage_scope = std::move(storage_scope); - data_ = std::move(node); -} - -TVM_FFI_REGISTER_GLOBAL("tir.ProducerRealize") - .set_body_typed([](DataProducer producer, Region bounds, PrimExpr condition, Stmt body, - String storage_scope, Span span) { - return ProducerRealize(producer, bounds, condition, body, storage_scope, span); - }); - -TVM_REGISTER_NODE_TYPE(ProducerRealizeNode); - -// Prefetch -Prefetch::Prefetch(Buffer buffer, Array bounds, Span span) { - data_ = make_object(buffer, bounds, span); -} - -TVM_FFI_REGISTER_GLOBAL("tir.Prefetch") - .set_body_typed([](Buffer buffer, Array bounds, Span span) { - return Prefetch(buffer, bounds, span); - }); - -TVM_REGISTER_NODE_TYPE(PrefetchNode); - // SeqStmt SeqStmt::SeqStmt(Array seq, Span span) { bool requires_flattening = std::any_of( diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index 85d347172702..4daa0e9a5468 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -94,27 +94,6 @@ void StmtVisitor::VisitStmt_(const AssertStmtNode* op) { this->VisitStmt(op->body); } -void StmtVisitor::VisitStmt_(const ProducerStoreNode* op) { - VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); - this->VisitExpr(op->value); -} - -void StmtVisitor::VisitStmt_(const ProducerRealizeNode* op) { - VisitArray(op->bounds, [this](const Range& r) { - this->VisitExpr(r->min); - this->VisitExpr(r->extent); - }); - this->VisitStmt(op->body); - this->VisitExpr(op->condition); -} - -void StmtVisitor::VisitStmt_(const PrefetchNode* op) { - VisitArray(op->bounds, [this](const Range& r) { - this->VisitExpr(r->min); - this->VisitExpr(r->extent); - }); -} - void StmtVisitor::VisitStmt_(const SeqStmtNode* op) { VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); }); } @@ -395,45 +374,6 @@ Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) { } } -Stmt StmtMutator::VisitStmt_(const ProducerStoreNode* op) { - Array indices = Internal::Mutate(this, op->indices); - PrimExpr value = this->VisitExpr(op->value); - if (indices.same_as(op->indices) && value.same_as(op->value)) { - return GetRef(op); - } else { - auto n = CopyOnWrite(op); - n->indices = std::move(indices); - n->value = std::move(value); - return Stmt(n); - } -} - -Stmt StmtMutator::VisitStmt_(const ProducerRealizeNode* op) { - Region bounds = Internal::Mutate(this, op->bounds); - Stmt body = this->VisitStmt(op->body); - PrimExpr condition = this->VisitExpr(op->condition); - if (bounds.same_as(op->bounds) && body.same_as(op->body) && condition.same_as(op->condition)) { - return GetRef(op); - } else { - auto n = CopyOnWrite(op); - n->bounds = std::move(bounds); - n->body = std::move(body); - n->condition = std::move(condition); - return Stmt(n); - } -} - -Stmt StmtMutator::VisitStmt_(const PrefetchNode* op) { - Region bounds = Internal::Mutate(this, op->bounds); - if (bounds.same_as(op->bounds)) { - return GetRef(op); - } else { - auto n = CopyOnWrite(op); - n->bounds = std::move(bounds); - return Stmt(n); - } -} - Stmt StmtMutator::VisitStmt_(const SeqStmtNode* op) { Array seq = Internal::Mutate(this, op->seq); if (seq.same_as(op->seq)) { diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc index 4f5007aedb3f..78cfd004dd4d 100644 --- a/src/tir/ir/tir_visitor_with_path.cc +++ b/src/tir/ir/tir_visitor_with_path.cc @@ -266,21 +266,6 @@ void TIRVisitorWithPath::VisitStmt_(const AssertStmtNode* op, ObjectPath path) { Visit(op->body, path->Attr("body")); } -void TIRVisitorWithPath::VisitStmt_(const ProducerStoreNode* op, ObjectPath path) { - Visit(op->indices, path->Attr("indices")); - Visit(op->value, path->Attr("value")); -} - -void TIRVisitorWithPath::VisitStmt_(const ProducerRealizeNode* op, ObjectPath path) { - Visit(op->bounds, path->Attr("bounds")); - Visit(op->body, path->Attr("body")); - Visit(op->condition, path->Attr("condition")); -} - -void TIRVisitorWithPath::VisitStmt_(const PrefetchNode* op, ObjectPath path) { - Visit(op->bounds, path->Attr("bounds")); -} - void TIRVisitorWithPath::VisitStmt_(const SeqStmtNode* op, ObjectPath path) { Visit(op->seq, path->Attr("seq")); } diff --git a/src/tir/ir/tir_visitor_with_path.h b/src/tir/ir/tir_visitor_with_path.h index 61441541da32..6b1cd8ace487 100644 --- a/src/tir/ir/tir_visitor_with_path.h +++ b/src/tir/ir/tir_visitor_with_path.h @@ -110,9 +110,6 @@ class TIRVisitorWithPath : protected ExprFunctorbody) ? MakeEvaluate({op->min, op->extent}) : stmt; } + Stmt VisitStmt_(const AllocateNode* op) final { Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); return is_no_op(op->body) ? MakeEvaluate(op->extents) : stmt; } - Stmt VisitStmt_(const ProducerRealizeNode* op) final { - Stmt stmt = StmtMutator::VisitStmt_(op); - op = stmt.as(); - return is_no_op(op->body) ? op->body : stmt; - } Stmt VisitStmt_(const EvaluateNode* op) final { if (HasSideEffect(op->value)) { return GetRef(op); diff --git a/src/tir/transforms/vectorize_loop.cc b/src/tir/transforms/vectorize_loop.cc index 7ae226c100c0..c23ce2828ce5 100644 --- a/src/tir/transforms/vectorize_loop.cc +++ b/src/tir/transforms/vectorize_loop.cc @@ -867,10 +867,6 @@ class Vectorizer : public StmtMutator, public ExprFunctordtype, 0), var_lanes_, ForKind::kSerial, stmt); } - // ProducerStore - Stmt VisitStmt_(const ProducerStoreNode* op) final { - LOG(FATAL) << "ProducerProvide cannot appear in a TIR PrimFunc"; - } private: // analyzer diff --git a/tests/python/tir-base/test_tir_constructor.py b/tests/python/tir-base/test_tir_constructor.py index 2df644d7e198..42c2998e27a8 100644 --- a/tests/python/tir-base/test_tir_constructor.py +++ b/tests/python/tir-base/test_tir_constructor.py @@ -186,10 +186,6 @@ def test_stmt_constructor(): assert x.then_case.value.value == 11 assert x.else_case == nop - b = tvm.tir.decl_buffer((1, 2)) - x = tvm.tir.Prefetch(b, []) - assert isinstance(x, tvm.tir.Prefetch) - def test_float_constructor_requires_float_dtype(): with pytest.raises(tvm.TVMError): diff --git a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py index daad7f53140b..31ba6fb164d4 100644 --- a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py +++ b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py @@ -482,21 +482,6 @@ def test_ir_builder_tir_buffer_store_predicate(): assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) -def test_ir_builder_tir_prefetch(): - with IRBuilder() as ib: - buffer_a = T.Buffer((128, 128), "float32") - T.prefetch(buffer_a, []) - - # the prefetch generated by IRBuilder - ir_actual = ib.get() - - # the expected prefetch - ir_expected = tir.Prefetch(buffer_a, []) - - # Check if the generated ir is expected - assert_structural_equal(ir_actual, ir_expected, map_free_vars=True) - - def test_ir_builder_tir_evaluate(): with IRBuilder() as ib: T.evaluate(0) diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py b/tests/python/tvmscript/test_tvmscript_printer_tir.py index e03cb6c9d583..267fae20cab3 100644 --- a/tests/python/tvmscript/test_tvmscript_printer_tir.py +++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py @@ -385,20 +385,6 @@ def test_decl_buffer(): ) -def test_prefetch(): - a = tir.decl_buffer((128, 128), "float16", name="A") - with IRBuilder() as ib: - T.prefetch(a, [Range(0, 64), Range(0, 64)]) - obj = ib.get() - _assert_print( - obj, - """ -A = T.Buffer((128, 128), "float16") -T.prefetch(A, [T.Range(0, 64), T.Range(0, 64)]) -""", - ) - - def test_seq_stmt(): with IRBuilder() as ib: with T.serial(10):