From 87fb4cc585e30dcd4c22d1a46961344d99569251 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Fri, 13 Jun 2025 15:45:07 +0800 Subject: [PATCH 1/4] phase out ProduceStore --- include/tvm/tir/stmt.h | 54 --------------------- include/tvm/tir/stmt_functor.h | 4 -- python/tvm/tir/__init__.py | 1 - python/tvm/tir/stmt.py | 36 -------------- src/script/printer/legacy_repr.cc | 15 ------ src/script/printer/tir/buffer.cc | 9 ---- src/tir/ir/stmt.cc | 18 ------- src/tir/ir/stmt_functor.cc | 18 ------- src/tir/ir/tir_visitor_with_path.cc | 5 -- src/tir/ir/tir_visitor_with_path.h | 1 - src/tir/transforms/inject_virtual_thread.cc | 2 - src/tir/transforms/vectorize_loop.cc | 4 -- 12 files changed, 167 deletions(-) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 6d93a3a153ad..4b80e97f416c 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -335,60 +335,6 @@ class BufferRealize : public Stmt { TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRealizeNode); }; -/*! - * \brief Store value into mult-dimensional array that will be read by the consumer - * of the producer. - * - * \note This node only appears in high-level DSLs that are built on top of the TIR. - * It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower - * this node before TIR transformations. - * - * \sa DataProducer - */ -class ProducerStoreNode : public StmtNode { - public: - /*! \brief The producer to store the results into. */ - DataProducer producer; - /*! \brief The value to be stored. */ - PrimExpr value; - /*! \brief The index arguments of the function. */ - Array indices; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("producer", &producer); - v->Visit("value", &value); - v->Visit("indices", &indices); - v->Visit("span", &span); - } - - bool SEqualReduce(const ProducerStoreNode* other, SEqualReducer equal) const { - return equal(producer, other->producer) && equal(value, other->value) && - equal(indices, other->indices); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(producer); - hash_reduce(value); - hash_reduce(indices); - } - - static constexpr const char* _type_key = "tir.ProducerStore"; - TVM_DECLARE_FINAL_OBJECT_INFO(ProducerStoreNode, StmtNode); -}; - -/*! - * \brief Managed reference to ProducerStoreNode. - * \sa ProducerStoreNode - */ -class ProducerStore : public Stmt { - public: - TVM_DLL ProducerStore(DataProducer producer, PrimExpr value, Array indices, - Span span = Span()); - - TVM_DEFINE_OBJECT_REF_METHODS(ProducerStore, Stmt, ProducerStoreNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerStoreNode); -}; - /*! * \brief Annotate the bounds where the data produced by the producer * need to be written and read in body. diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 141fe710b371..071922b7eb46 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -93,7 +93,6 @@ class StmtFunctor { virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const ProducerStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const ProducerRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const PrefetchNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; @@ -118,7 +117,6 @@ class StmtFunctor { IR_STMT_FUNCTOR_DISPATCH(AllocateConstNode); IR_STMT_FUNCTOR_DISPATCH(DeclBufferNode); IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode); - IR_STMT_FUNCTOR_DISPATCH(ProducerStoreNode); IR_STMT_FUNCTOR_DISPATCH(ProducerRealizeNode); IR_STMT_FUNCTOR_DISPATCH(PrefetchNode); IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode); @@ -164,7 +162,6 @@ class TVM_DLL StmtVisitor : protected StmtFunctor { void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const BufferRealizeNode* op) override; void VisitStmt_(const AssertStmtNode* op) override; - void VisitStmt_(const ProducerStoreNode* op) override; void VisitStmt_(const ProducerRealizeNode* op) override; void VisitStmt_(const PrefetchNode* op) override; void VisitStmt_(const SeqStmtNode* op) override; @@ -265,7 +262,6 @@ class TVM_DLL StmtMutator : protected StmtFunctor { Stmt VisitStmt_(const BufferStoreNode* op) override; Stmt VisitStmt_(const BufferRealizeNode* op) override; Stmt VisitStmt_(const AssertStmtNode* op) override; - Stmt VisitStmt_(const ProducerStoreNode* op) override; Stmt VisitStmt_(const ProducerRealizeNode* op) override; Stmt VisitStmt_(const PrefetchNode* op) override; Stmt VisitStmt_(const SeqStmtNode* op) override; diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 24db80cb651a..0ce5d27baf9e 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -32,7 +32,6 @@ from .stmt import ( BufferStore, BufferRealize, - ProducerStore, Allocate, AllocateConst, AttrStmt, diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index a04f80b55e7a..e756dd40aa2f 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -293,42 +293,6 @@ def __init__( ) -@tvm.ffi.register_object("tir.ProducerStore") -class ProducerStore(Stmt): - """ProducerStore node. - - Parameters - ---------- - producer : DataProducer - The data producer. - - value : PrimExpr - The value to be stored. - - indices : list of Expr - The index arguments of the store. - - span : Optional[Span] - The location of the stmt in the source code. - """ - - producer: DataProducer - value: PrimExpr - indices: List[PrimExpr] - span: Optional[Span] - - def __init__( - self, - producer: DataProducer, - value: PrimExpr, - indices: List[PrimExpr], - span: Optional[Span] = None, - ) -> None: - self.__init_handle_by_constructor__( - _ffi_api.ProducerStore, producer, value, indices, span # type: ignore - ) - - @tvm.ffi.register_object("tir.Allocate") class Allocate(Stmt): """Allocate node. diff --git a/src/script/printer/legacy_repr.cc b/src/script/printer/legacy_repr.cc index 5e414e90c262..94648cf1be5b 100644 --- a/src/script/printer/legacy_repr.cc +++ b/src/script/printer/legacy_repr.cc @@ -625,21 +625,6 @@ TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) (*p) << "}\n"; }); -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - (*p) << op->producer->GetNameHint() << "["; - for (size_t i = 0; i < op->indices.size(); ++i) { - p->Print(op->indices[i]); - if (i < op->indices.size() - 1) (*p) << ", "; - } - (*p) << "]"; - (*p) << " ="; - p->Print(op->value); - (*p) << '\n'; - }); - TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { auto* op = static_cast(node.get()); diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc index 0427c359049b..8c9a4e33743e 100644 --- a/src/script/printer/tir/buffer.cc +++ b/src/script/printer/tir/buffer.cc @@ -338,14 +338,6 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return prefix[BufferIndices(load->indices, p->Attr("indices"), d)]; }); -TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "", [](tir::ProducerStore store, ObjectPath p, IRDocsifier d) -> Doc { - ExprDoc prefix = IdDoc(store->producer->GetNameHint()); - prefix = prefix[BufferIndices(store->indices, p->Attr("indices"), d)]; - return AssignDoc(prefix, d->AsDoc(store->value, p->Attr("value")), std::nullopt); - }); - TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) .set_dispatch( // "", [](tir::ProducerRealize stmt, ObjectPath p, IRDocsifier d) -> Doc { @@ -364,7 +356,6 @@ TVM_SCRIPT_REPR(tir::BufferStoreNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::BufferNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::MatchBufferRegionNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::ProducerLoadNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::ProducerStoreNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::ProducerRealizeNode, ReprPrintTIR); } // namespace printer diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 62baf45bc78e..b0ba86b5cadc 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -205,24 +205,6 @@ TVM_FFI_REGISTER_GLOBAL("tir.While").set_body_typed([](PrimExpr condition, Stmt TVM_REGISTER_NODE_TYPE(WhileNode); -// ProducerStore -ProducerStore::ProducerStore(DataProducer producer, PrimExpr value, Array indices, - Span span) { - ObjectPtr node = make_object(); - node->producer = std::move(producer); - node->value = std::move(value); - node->indices = std::move(indices); - node->span = std::move(span); - data_ = std::move(node); -} - -TVM_FFI_REGISTER_GLOBAL("tir.ProducerStore") - .set_body_typed([](DataProducer producer, PrimExpr value, Array indices, Span span) { - return ProducerStore(producer, value, indices, span); - }); - -TVM_REGISTER_NODE_TYPE(ProducerStoreNode); - // Allocate Allocate::Allocate(Var buffer_var, DataType dtype, Array extents, PrimExpr condition, Stmt body, Map annotations, Span span) { diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index 85d347172702..306e0d92bac6 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -94,11 +94,6 @@ void StmtVisitor::VisitStmt_(const AssertStmtNode* op) { this->VisitStmt(op->body); } -void StmtVisitor::VisitStmt_(const ProducerStoreNode* op) { - VisitArray(op->indices, [this](const PrimExpr& e) { this->VisitExpr(e); }); - this->VisitExpr(op->value); -} - void StmtVisitor::VisitStmt_(const ProducerRealizeNode* op) { VisitArray(op->bounds, [this](const Range& r) { this->VisitExpr(r->min); @@ -395,19 +390,6 @@ Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) { } } -Stmt StmtMutator::VisitStmt_(const ProducerStoreNode* op) { - Array indices = Internal::Mutate(this, op->indices); - PrimExpr value = this->VisitExpr(op->value); - if (indices.same_as(op->indices) && value.same_as(op->value)) { - return GetRef(op); - } else { - auto n = CopyOnWrite(op); - n->indices = std::move(indices); - n->value = std::move(value); - return Stmt(n); - } -} - Stmt StmtMutator::VisitStmt_(const ProducerRealizeNode* op) { Region bounds = Internal::Mutate(this, op->bounds); Stmt body = this->VisitStmt(op->body); diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc index 4f5007aedb3f..7ee5ff6615ce 100644 --- a/src/tir/ir/tir_visitor_with_path.cc +++ b/src/tir/ir/tir_visitor_with_path.cc @@ -266,11 +266,6 @@ void TIRVisitorWithPath::VisitStmt_(const AssertStmtNode* op, ObjectPath path) { Visit(op->body, path->Attr("body")); } -void TIRVisitorWithPath::VisitStmt_(const ProducerStoreNode* op, ObjectPath path) { - Visit(op->indices, path->Attr("indices")); - Visit(op->value, path->Attr("value")); -} - void TIRVisitorWithPath::VisitStmt_(const ProducerRealizeNode* op, ObjectPath path) { Visit(op->bounds, path->Attr("bounds")); Visit(op->body, path->Attr("body")); diff --git a/src/tir/ir/tir_visitor_with_path.h b/src/tir/ir/tir_visitor_with_path.h index 61441541da32..daa987463c8d 100644 --- a/src/tir/ir/tir_visitor_with_path.h +++ b/src/tir/ir/tir_visitor_with_path.h @@ -110,7 +110,6 @@ class TIRVisitorWithPath : protected ExprFunctordtype, 0), var_lanes_, ForKind::kSerial, stmt); } - // ProducerStore - Stmt VisitStmt_(const ProducerStoreNode* op) final { - LOG(FATAL) << "ProducerProvide cannot appear in a TIR PrimFunc"; - } private: // analyzer From 95896cdd59b75893ad17c843572a1a6a05dbe4c1 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Fri, 13 Jun 2025 15:47:40 +0800 Subject: [PATCH 2/4] phase out ProducerRealize --- include/tvm/tir/stmt.h | 64 ----------------------------- include/tvm/tir/stmt_functor.h | 4 -- python/tvm/tir/__init__.py | 2 +- python/tvm/tir/stmt.py | 52 ----------------------- src/script/printer/legacy_repr.cc | 28 ------------- src/script/printer/tir/buffer.cc | 13 ------ src/tir/ir/stmt.cc | 31 -------------- src/tir/ir/stmt_functor.cc | 24 ----------- src/tir/ir/tir_visitor_with_path.cc | 6 --- src/tir/ir/tir_visitor_with_path.h | 1 - src/tir/transforms/remove_no_op.cc | 6 +-- 11 files changed, 2 insertions(+), 229 deletions(-) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 4b80e97f416c..177eac0557dc 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -335,70 +335,6 @@ class BufferRealize : public Stmt { TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRealizeNode); }; -/*! - * \brief Annotate the bounds where the data produced by the producer - * need to be written and read in body. - * We will need to allocate space for the corresponding regions. - * - * \note This node only appears in high-level DSLs that are built on top of the TIR. - * It should not appear in a valid TIR PrimFunc. A high-level DSL needs to lower - * this node before TIR transformations. - * - * \sa DataProducer - */ -class ProducerRealizeNode : public StmtNode { - public: - /*! \brief The producer that produces the data. */ - DataProducer producer; - /*! \brief Bounds to be realized. */ - Region bounds; - /*! \brief Only realize if condition holds. */ - PrimExpr condition; - /*! \brief The body of realization. */ - Stmt body; - /*! \brief The storage scope associated with this realization. */ - String storage_scope; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("producer", &producer); - v->Visit("bounds", &bounds); - v->Visit("condition", &condition); - v->Visit("body", &body); - v->Visit("storage_scope", &storage_scope); - v->Visit("span", &span); - } - - bool SEqualReduce(const ProducerRealizeNode* other, SEqualReducer equal) const { - return equal(producer, other->producer) && equal(bounds, other->bounds) && - equal(condition, other->condition) && equal(body, other->body) && - equal(storage_scope, other->storage_scope); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(producer); - hash_reduce(bounds); - hash_reduce(condition); - hash_reduce(body); - hash_reduce(storage_scope); - } - - static constexpr const char* _type_key = "tir.ProducerRealize"; - TVM_DECLARE_FINAL_OBJECT_INFO(ProducerRealizeNode, StmtNode); -}; - -/*! - * \brief Managed reference to ProducerRealizeNode. - * \sa ProducerRealizeNode - */ -class ProducerRealize : public Stmt { - public: - TVM_DLL ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, Stmt body, - String storage_scope = "", Span span = Span()); - - TVM_DEFINE_OBJECT_REF_METHODS(ProducerRealize, Stmt, ProducerRealizeNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(ProducerRealizeNode); -}; - /*! * \brief Allocate a buffer that can be used in body. */ diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 071922b7eb46..a36abce22a7b 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -93,7 +93,6 @@ class StmtFunctor { virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const ProducerRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const PrefetchNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; @@ -117,7 +116,6 @@ class StmtFunctor { IR_STMT_FUNCTOR_DISPATCH(AllocateConstNode); IR_STMT_FUNCTOR_DISPATCH(DeclBufferNode); IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode); - IR_STMT_FUNCTOR_DISPATCH(ProducerRealizeNode); IR_STMT_FUNCTOR_DISPATCH(PrefetchNode); IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode); IR_STMT_FUNCTOR_DISPATCH(EvaluateNode); @@ -162,7 +160,6 @@ class TVM_DLL StmtVisitor : protected StmtFunctor { void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const BufferRealizeNode* op) override; void VisitStmt_(const AssertStmtNode* op) override; - void VisitStmt_(const ProducerRealizeNode* op) override; void VisitStmt_(const PrefetchNode* op) override; void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const EvaluateNode* op) override; @@ -262,7 +259,6 @@ class TVM_DLL StmtMutator : protected StmtFunctor { Stmt VisitStmt_(const BufferStoreNode* op) override; Stmt VisitStmt_(const BufferRealizeNode* op) override; Stmt VisitStmt_(const AssertStmtNode* op) override; - Stmt VisitStmt_(const ProducerRealizeNode* op) override; Stmt VisitStmt_(const PrefetchNode* op) override; Stmt VisitStmt_(const SeqStmtNode* op) override; Stmt VisitStmt_(const EvaluateNode* op) override; diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 0ce5d27baf9e..df8a4f816396 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -38,7 +38,7 @@ DeclBuffer, ) -from .stmt import ProducerRealize, SeqStmt +from .stmt import SeqStmt from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index e756dd40aa2f..0a217c31472b 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -475,58 +475,6 @@ def __init__( ) -@tvm.ffi.register_object("tir.ProducerRealize") -class ProducerRealize(Stmt): - """ProducerRealize node. - - Parameters - ---------- - producer : DataProducer - The data producer. - - bounds : List[Range] - The bound of realize - - condition : PrimExpr - The realize condition. - - body : Stmt - The realize body - - storage_scope : str - The storage scope associated with this realization - - span : Optional[Span] - The location of the stmt in the source code. - """ - - producer: DataProducer - bounds: List[Range] - condition: PrimExpr - body: Stmt - storage_scope: str - span: Optional[Span] - - def __init__( - self, - producer: DataProducer, - bounds: List[Range], - condition: PrimExpr, - body: Stmt, - storage_scope: str = "", - span: Optional[Span] = None, - ) -> None: - self.__init_handle_by_constructor__( - _ffi_api.ProducerRealize, # type: ignore - producer, - bounds, - condition, - body, - storage_scope, - span, - ) - - @tvm.ffi.register_object("tir.SeqStmt") class SeqStmt(Stmt): """Sequence of statements. diff --git a/src/script/printer/legacy_repr.cc b/src/script/printer/legacy_repr.cc index 94648cf1be5b..8dc1c132dda2 100644 --- a/src/script/printer/legacy_repr.cc +++ b/src/script/printer/legacy_repr.cc @@ -667,34 +667,6 @@ TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) (*p) << op->body; }); -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - (*p) << "producer_realize " << op->producer->GetNameHint() << "("; - for (size_t i = 0; i < op->bounds.size(); ++i) { - (*p) << "["; - p->Print(op->bounds[i]->min); - (*p) << ", "; - p->Print(op->bounds[i]->extent); - (*p) << "]"; - if (i < op->bounds.size() - 1) (*p) << ", "; - } - (*p) << ")"; - if (!is_one(op->condition)) { - (*p) << " if "; - p->Print(op->condition); - } - (*p) << " {\n"; - - p->indent += 2; - p->Print(op->body); - p->indent -= 2; - - p->PrintIndent(); - (*p) << "}\n"; - }); - TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { auto* op = static_cast(node.get()); diff --git a/src/script/printer/tir/buffer.cc b/src/script/printer/tir/buffer.cc index 8c9a4e33743e..f1b75066cdba 100644 --- a/src/script/printer/tir/buffer.cc +++ b/src/script/printer/tir/buffer.cc @@ -338,25 +338,12 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return prefix[BufferIndices(load->indices, p->Attr("indices"), d)]; }); -TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "", [](tir::ProducerRealize stmt, ObjectPath p, IRDocsifier d) -> Doc { - ExprDoc prefix = IdDoc(stmt->producer->GetNameHint()); - prefix = prefix[BufferSlices(stmt->bounds, p->Attr("bounds"), d)]; - prefix = TIR(d, "ProducerRealize") - ->Call({prefix, d->AsDoc(stmt->condition, p->Attr("condition"))}); - With f(d, stmt); - AsDocBody(stmt->body, p->Attr("body"), f->get(), d); - return ScopeDoc(std::nullopt, prefix, (*f)->stmts); - }); - TVM_SCRIPT_REPR(tir::BufferRegionNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::BufferLoadNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::BufferStoreNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::BufferNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::MatchBufferRegionNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::ProducerLoadNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::ProducerRealizeNode, ReprPrintTIR); } // namespace printer } // namespace script diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index b0ba86b5cadc..409777c8f7c5 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -335,37 +335,6 @@ TVM_FFI_REGISTER_GLOBAL("tir.DeclBuffer").set_body_typed([](Buffer buffer, Stmt TVM_REGISTER_NODE_TYPE(DeclBufferNode); -// ProducerRealize -ProducerRealize::ProducerRealize(DataProducer producer, Region bounds, PrimExpr condition, - Stmt body, String storage_scope, Span span) { - for (size_t i = 0; i < bounds.size(); ++i) { - ICHECK(bounds[i]->min.defined()); - ICHECK(bounds[i]->extent.defined()); - ICHECK(bounds[i]->min.dtype().is_scalar()); - ICHECK(bounds[i]->extent.dtype().is_scalar()); - } - ICHECK(body.defined()); - ICHECK(condition.defined()); - ICHECK(condition.dtype().is_bool()); - - ObjectPtr node = make_object(); - node->producer = std::move(producer); - node->bounds = std::move(bounds); - node->condition = std::move(condition); - node->body = std::move(body); - node->span = std::move(span); - node->storage_scope = std::move(storage_scope); - data_ = std::move(node); -} - -TVM_FFI_REGISTER_GLOBAL("tir.ProducerRealize") - .set_body_typed([](DataProducer producer, Region bounds, PrimExpr condition, Stmt body, - String storage_scope, Span span) { - return ProducerRealize(producer, bounds, condition, body, storage_scope, span); - }); - -TVM_REGISTER_NODE_TYPE(ProducerRealizeNode); - // Prefetch Prefetch::Prefetch(Buffer buffer, Array bounds, Span span) { data_ = make_object(buffer, bounds, span); diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index 306e0d92bac6..098aa726d9a4 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -94,15 +94,6 @@ void StmtVisitor::VisitStmt_(const AssertStmtNode* op) { this->VisitStmt(op->body); } -void StmtVisitor::VisitStmt_(const ProducerRealizeNode* op) { - VisitArray(op->bounds, [this](const Range& r) { - this->VisitExpr(r->min); - this->VisitExpr(r->extent); - }); - this->VisitStmt(op->body); - this->VisitExpr(op->condition); -} - void StmtVisitor::VisitStmt_(const PrefetchNode* op) { VisitArray(op->bounds, [this](const Range& r) { this->VisitExpr(r->min); @@ -390,21 +381,6 @@ Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) { } } -Stmt StmtMutator::VisitStmt_(const ProducerRealizeNode* op) { - Region bounds = Internal::Mutate(this, op->bounds); - Stmt body = this->VisitStmt(op->body); - PrimExpr condition = this->VisitExpr(op->condition); - if (bounds.same_as(op->bounds) && body.same_as(op->body) && condition.same_as(op->condition)) { - return GetRef(op); - } else { - auto n = CopyOnWrite(op); - n->bounds = std::move(bounds); - n->body = std::move(body); - n->condition = std::move(condition); - return Stmt(n); - } -} - Stmt StmtMutator::VisitStmt_(const PrefetchNode* op) { Region bounds = Internal::Mutate(this, op->bounds); if (bounds.same_as(op->bounds)) { diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc index 7ee5ff6615ce..2a5026e4ae94 100644 --- a/src/tir/ir/tir_visitor_with_path.cc +++ b/src/tir/ir/tir_visitor_with_path.cc @@ -266,12 +266,6 @@ void TIRVisitorWithPath::VisitStmt_(const AssertStmtNode* op, ObjectPath path) { Visit(op->body, path->Attr("body")); } -void TIRVisitorWithPath::VisitStmt_(const ProducerRealizeNode* op, ObjectPath path) { - Visit(op->bounds, path->Attr("bounds")); - Visit(op->body, path->Attr("body")); - Visit(op->condition, path->Attr("condition")); -} - void TIRVisitorWithPath::VisitStmt_(const PrefetchNode* op, ObjectPath path) { Visit(op->bounds, path->Attr("bounds")); } diff --git a/src/tir/ir/tir_visitor_with_path.h b/src/tir/ir/tir_visitor_with_path.h index daa987463c8d..f000d595c0be 100644 --- a/src/tir/ir/tir_visitor_with_path.h +++ b/src/tir/ir/tir_visitor_with_path.h @@ -110,7 +110,6 @@ class TIRVisitorWithPath : protected ExprFunctorbody) ? MakeEvaluate({op->min, op->extent}) : stmt; } + Stmt VisitStmt_(const AllocateNode* op) final { Stmt stmt = StmtMutator::VisitStmt_(op); op = stmt.as(); return is_no_op(op->body) ? MakeEvaluate(op->extents) : stmt; } - Stmt VisitStmt_(const ProducerRealizeNode* op) final { - Stmt stmt = StmtMutator::VisitStmt_(op); - op = stmt.as(); - return is_no_op(op->body) ? op->body : stmt; - } Stmt VisitStmt_(const EvaluateNode* op) final { if (HasSideEffect(op->value)) { return GetRef(op); From 91dab15610ac10bc4f165847ff9b1675b853a632 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Fri, 13 Jun 2025 15:53:47 +0800 Subject: [PATCH 3/4] phase out Prefetch --- include/tvm/script/ir_builder/tir/ir.h | 7 --- include/tvm/tir/stmt.h | 45 ------------------- include/tvm/tir/stmt_functor.h | 4 -- python/tvm/script/ir_builder/tir/ir.py | 16 ------- python/tvm/tir/__init__.py | 2 +- python/tvm/tir/stmt.py | 24 ---------- src/script/ir_builder/tir/ir.cc | 5 --- src/script/printer/legacy_repr.cc | 16 ------- src/script/printer/tir/stmt.cc | 11 ----- src/tir/ir/stmt.cc | 11 ----- src/tir/ir/stmt_functor.cc | 18 -------- src/tir/ir/tir_visitor_with_path.cc | 4 -- src/tir/ir/tir_visitor_with_path.h | 1 - tests/python/tir-base/test_tir_constructor.py | 4 -- .../test_tvmscript_ir_builder_tir.py | 15 ------- .../tvmscript/test_tvmscript_printer_tir.py | 14 ------ 16 files changed, 1 insertion(+), 196 deletions(-) diff --git a/include/tvm/script/ir_builder/tir/ir.h b/include/tvm/script/ir_builder/tir/ir.h index b36f5cd7384d..febdac55d9aa 100644 --- a/include/tvm/script/ir_builder/tir/ir.h +++ b/include/tvm/script/ir_builder/tir/ir.h @@ -417,13 +417,6 @@ Var EnvThread(String thread_tag, DataType dtype = DataType::Int(32)); void BufferStore(Buffer buffer, PrimExpr value, Array indices, Optional predicate); -/*! - * \brief The prefetch hint for a buffer - * \param buffer The buffer to be prefetched. - * \param bounds The bounds to be prefetched. - */ -void Prefetch(Buffer buffer, Array bounds); - /*! * \brief Evaluate the input expression. * \param value The input expression to evaluate. diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 177eac0557dc..6df0382fd90b 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -972,51 +972,6 @@ class While : public Stmt { TVM_DEFINE_OBJECT_REF_COW_METHOD(WhileNode); }; -/*! - * \brief A prefetch hint for a buffer - */ -class PrefetchNode : public StmtNode { - public: - /*! \brief The function to be prefetched. */ - Buffer buffer; - /*! \brief Bounds to be prefetched. */ - Array bounds; - - void VisitAttrs(AttrVisitor* v) { - v->Visit("buffer", &buffer); - v->Visit("bounds", &bounds); - v->Visit("span", &span); - } - - bool SEqualReduce(const PrefetchNode* other, SEqualReducer equal) const { - return equal(buffer, other->buffer) && equal(bounds, other->bounds); - } - - void SHashReduce(SHashReducer hash_reduce) const { - hash_reduce(buffer); - hash_reduce(bounds); - } - - PrefetchNode() = default; - PrefetchNode(Buffer buffer, Array bounds, Span span = Span()) - : StmtNode(span), buffer(buffer), bounds(bounds) {} - - static constexpr const char* _type_key = "tir.Prefetch"; - TVM_DECLARE_FINAL_OBJECT_INFO(PrefetchNode, StmtNode); -}; - -/*! - * \brief Managed reference to PrefetchNode. - * \sa PrefetchNode - */ -class Prefetch : public Stmt { - public: - TVM_DLL explicit Prefetch(Buffer buffer, Array bounds, Span span = Span()); - - TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode); - TVM_DEFINE_OBJECT_REF_COW_METHOD(PrefetchNode); -}; - /*! * \brief Representing the region of multi-dimensional buffer access. */ diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index a36abce22a7b..23747a7e936c 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -93,7 +93,6 @@ class StmtFunctor { virtual R VisitStmt_(const BufferStoreNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const BufferRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const AssertStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; - virtual R VisitStmt_(const PrefetchNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const BlockNode* op, Args... args) STMT_FUNCTOR_DEFAULT; @@ -116,7 +115,6 @@ class StmtFunctor { IR_STMT_FUNCTOR_DISPATCH(AllocateConstNode); IR_STMT_FUNCTOR_DISPATCH(DeclBufferNode); IR_STMT_FUNCTOR_DISPATCH(AssertStmtNode); - IR_STMT_FUNCTOR_DISPATCH(PrefetchNode); IR_STMT_FUNCTOR_DISPATCH(SeqStmtNode); IR_STMT_FUNCTOR_DISPATCH(EvaluateNode); IR_STMT_FUNCTOR_DISPATCH(BufferStoreNode); @@ -160,7 +158,6 @@ class TVM_DLL StmtVisitor : protected StmtFunctor { void VisitStmt_(const BufferStoreNode* op) override; void VisitStmt_(const BufferRealizeNode* op) override; void VisitStmt_(const AssertStmtNode* op) override; - void VisitStmt_(const PrefetchNode* op) override; void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const EvaluateNode* op) override; void VisitStmt_(const BlockNode* op) override; @@ -259,7 +256,6 @@ class TVM_DLL StmtMutator : protected StmtFunctor { Stmt VisitStmt_(const BufferStoreNode* op) override; Stmt VisitStmt_(const BufferRealizeNode* op) override; Stmt VisitStmt_(const AssertStmtNode* op) override; - Stmt VisitStmt_(const PrefetchNode* op) override; Stmt VisitStmt_(const SeqStmtNode* op) override; Stmt VisitStmt_(const EvaluateNode* op) override; Stmt VisitStmt_(const BlockNode* op) override; diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 1aaeaa034724..0db45f521bcf 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1316,21 +1316,6 @@ def buffer_store( ) -def prefetch( - buffer: Buffer, # pylint: disable=redefined-outer-name - bounds: List[ir.Range], -) -> None: - """The prefetch hint for a buffer. - - Parameters - ---------- - buffer : Buffer - The buffer to be prefetched. - bounds : List[Range] - The range to be prefetched. - """ - return _ffi_api.Prefetch(buffer, bounds) # type: ignore[attr-defined] # pylint: disable=no-member - def evaluate(value: PrimExpr) -> None: """Evaluate the input expression. @@ -2144,7 +2129,6 @@ def wrapped(*args, **kwargs): "launch_thread", "env_thread", "buffer_store", - "prefetch", "evaluate", "boolean", "handle", diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index df8a4f816396..5c4a2b91f5d7 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -39,7 +39,7 @@ ) from .stmt import SeqStmt -from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list +from .stmt import IfThenElse, Evaluate, stmt_seq, stmt_list from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize from .function import PrimFunc, TensorIntrin, IndexMap diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 0a217c31472b..abdd32829c9f 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -556,30 +556,6 @@ def __init__(self, value: PrimExpr, span: Optional[Span] = None) -> None: self.__init_handle_by_constructor__(_ffi_api.Evaluate, value, span) # type: ignore -@tvm.ffi.register_object("tir.Prefetch") -class Prefetch(Stmt): - """Prefetch node. - - Parameters - ---------- - buffer : Buffer - The buffer to be prefetched. - - bounds : List[Range] - The bounds to be prefetched. - - span : Optional[Span] - The location of the stmt in the source code. - """ - - buffer: Buffer - bounds: List[Range] - span: Optional[Span] - - def __init__(self, buffer: Buffer, bounds: List[Range], span: Optional[Span] = None) -> None: - self.__init_handle_by_constructor__(_ffi_api.Prefetch, buffer, bounds, span) # type: ignore - - @tvm.ffi.register_object("tir.BufferRegion") class BufferRegion(Object, Scriptable): """BufferRegion node. diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc index 2d61ca3e75f5..6f73254ff2ab 100644 --- a/src/script/ir_builder/tir/ir.cc +++ b/src/script/ir_builder/tir/ir.cc @@ -596,10 +596,6 @@ void BufferStore(Buffer buffer, PrimExpr value, Array indices, AddToParent(tvm::tir::BufferStore(buffer, value, indices, predicate)); } -void Prefetch(Buffer buffer, Array bounds) { - AddToParent(tvm::tir::Prefetch(buffer, bounds)); -} - DeclBufferFrame DeclBuffer(Array shape, DataType dtype, String buffer_name, Optional data, Optional> strides, Optional elem_offset, String storage_scope, int align, @@ -724,7 +720,6 @@ TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.LaunchThread") TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.EnvThread").set_body_typed(EnvThread); TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.BufferStore").set_body_typed(BufferStore); -TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Prefetch").set_body_typed(Prefetch); TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Evaluate").set_body_typed(Evaluate); TVM_FFI_REGISTER_GLOBAL("script.ir_builder.tir.Ptr").set_body_typed(Ptr); diff --git a/src/script/printer/legacy_repr.cc b/src/script/printer/legacy_repr.cc index 8dc1c132dda2..57dd691b8897 100644 --- a/src/script/printer/legacy_repr.cc +++ b/src/script/printer/legacy_repr.cc @@ -667,22 +667,6 @@ TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) (*p) << op->body; }); -TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) - .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { - auto* op = static_cast(node.get()); - p->PrintIndent(); - (*p) << "prefetch " << op->buffer << "("; - for (size_t i = 0; i < op->bounds.size(); ++i) { - (*p) << "["; - p->Print(op->bounds[i]->min); - (*p) << ", "; - p->Print(op->bounds[i]->extent); - (*p) << "]"; - if (i < op->bounds.size() - 1) (*p) << ", "; - } - (*p) << ")"; - }); - TVM_STATIC_IR_FUNCTOR(ReprLegacyPrinter, vtable) .set_dispatch([](const ObjectRef& node, ReprLegacyPrinter* p) { auto* op = static_cast(node.get()); diff --git a/src/script/printer/tir/stmt.cc b/src/script/printer/tir/stmt.cc index 1d310c2a5a9f..239d9f67216f 100644 --- a/src/script/printer/tir/stmt.cc +++ b/src/script/printer/tir/stmt.cc @@ -199,16 +199,6 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) return StmtBlockDoc((*f)->stmts); }); -TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable) - .set_dispatch( // - "", [](tir::Prefetch stmt, ObjectPath p, IRDocsifier d) -> Doc { - return ExprStmtDoc(TIR(d, "prefetch") - ->Call({ - d->AsDoc(stmt->buffer, p->Attr("buffer")), - d->AsDoc(stmt->bounds, p->Attr("bounds")), - })); - }); - bool IsAllocateDeclBufferPattern(const tir::AllocateNode* allocate) { const tir::Var& buffer_var = allocate->buffer_var; if (const tir::DeclBufferNode* decl_buffer = allocate->body.as()) { @@ -462,7 +452,6 @@ TVM_SCRIPT_REPR(tir::WhileNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::AllocateNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::AllocateConstNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::DeclBufferNode, ReprPrintTIR); -TVM_SCRIPT_REPR(tir::PrefetchNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::SeqStmtNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::IfThenElseNode, ReprPrintTIR); TVM_SCRIPT_REPR(tir::EvaluateNode, ReprPrintTIR); diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 409777c8f7c5..267fe13d1ff2 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -335,17 +335,6 @@ TVM_FFI_REGISTER_GLOBAL("tir.DeclBuffer").set_body_typed([](Buffer buffer, Stmt TVM_REGISTER_NODE_TYPE(DeclBufferNode); -// Prefetch -Prefetch::Prefetch(Buffer buffer, Array bounds, Span span) { - data_ = make_object(buffer, bounds, span); -} - -TVM_FFI_REGISTER_GLOBAL("tir.Prefetch") - .set_body_typed([](Buffer buffer, Array bounds, Span span) { - return Prefetch(buffer, bounds, span); - }); - -TVM_REGISTER_NODE_TYPE(PrefetchNode); // SeqStmt SeqStmt::SeqStmt(Array seq, Span span) { diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index 098aa726d9a4..4daa0e9a5468 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -94,13 +94,6 @@ void StmtVisitor::VisitStmt_(const AssertStmtNode* op) { this->VisitStmt(op->body); } -void StmtVisitor::VisitStmt_(const PrefetchNode* op) { - VisitArray(op->bounds, [this](const Range& r) { - this->VisitExpr(r->min); - this->VisitExpr(r->extent); - }); -} - void StmtVisitor::VisitStmt_(const SeqStmtNode* op) { VisitArray(op->seq, [this](const Stmt& s) { this->VisitStmt(s); }); } @@ -381,17 +374,6 @@ Stmt StmtMutator::VisitStmt_(const BufferRealizeNode* op) { } } -Stmt StmtMutator::VisitStmt_(const PrefetchNode* op) { - Region bounds = Internal::Mutate(this, op->bounds); - if (bounds.same_as(op->bounds)) { - return GetRef(op); - } else { - auto n = CopyOnWrite(op); - n->bounds = std::move(bounds); - return Stmt(n); - } -} - Stmt StmtMutator::VisitStmt_(const SeqStmtNode* op) { Array seq = Internal::Mutate(this, op->seq); if (seq.same_as(op->seq)) { diff --git a/src/tir/ir/tir_visitor_with_path.cc b/src/tir/ir/tir_visitor_with_path.cc index 2a5026e4ae94..78cfd004dd4d 100644 --- a/src/tir/ir/tir_visitor_with_path.cc +++ b/src/tir/ir/tir_visitor_with_path.cc @@ -266,10 +266,6 @@ void TIRVisitorWithPath::VisitStmt_(const AssertStmtNode* op, ObjectPath path) { Visit(op->body, path->Attr("body")); } -void TIRVisitorWithPath::VisitStmt_(const PrefetchNode* op, ObjectPath path) { - Visit(op->bounds, path->Attr("bounds")); -} - void TIRVisitorWithPath::VisitStmt_(const SeqStmtNode* op, ObjectPath path) { Visit(op->seq, path->Attr("seq")); } diff --git a/src/tir/ir/tir_visitor_with_path.h b/src/tir/ir/tir_visitor_with_path.h index f000d595c0be..6b1cd8ace487 100644 --- a/src/tir/ir/tir_visitor_with_path.h +++ b/src/tir/ir/tir_visitor_with_path.h @@ -110,7 +110,6 @@ class TIRVisitorWithPath : protected ExprFunctor Date: Sat, 14 Jun 2025 19:01:32 +0800 Subject: [PATCH 4/4] lint --- include/tvm/tir/stmt.h | 1 - python/tvm/script/ir_builder/tir/ir.py | 1 - python/tvm/tir/stmt.py | 2 +- src/tir/ir/stmt.cc | 1 - 4 files changed, 1 insertion(+), 4 deletions(-) diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 6df0382fd90b..cb5db7e44f8a 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -29,7 +29,6 @@ #include #include #include -#include namespace tvm { namespace tir { diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index 0db45f521bcf..c7589f4a19a6 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1316,7 +1316,6 @@ def buffer_store( ) - def evaluate(value: PrimExpr) -> None: """Evaluate the input expression. diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index abdd32829c9f..ffb6fd6a7068 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -34,7 +34,7 @@ from tvm.runtime import Object, Scriptable, const, NDArray from . import _ffi_api -from .buffer import Buffer, DataProducer +from .buffer import Buffer from .expr import Var, IterVar diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 267fe13d1ff2..f400ca0d507e 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -335,7 +335,6 @@ TVM_FFI_REGISTER_GLOBAL("tir.DeclBuffer").set_body_typed([](Buffer buffer, Stmt TVM_REGISTER_NODE_TYPE(DeclBufferNode); - // SeqStmt SeqStmt::SeqStmt(Array seq, Span span) { bool requires_flattening = std::any_of(