diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 093d49ca2dd4..074bcdd3f533 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -862,7 +862,7 @@ class For : public Stmt { }; /*! - * \brief A prefetch hint for abuffer + * \brief A prefetch hint for a buffer */ class PrefetchNode : public StmtNode { public: @@ -905,6 +905,252 @@ class Prefetch : public Stmt { TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode); }; +/*! + * \brief Representing the region of multi-dimensional buffer access. + */ +class BufferRegionNode : public Object { + public: + /*! \brief The buffer of the buffer region. */ + Buffer buffer; + /*! \brief The region array of the buffer region. */ + Array region; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("buffer", &buffer); + v->Visit("region", ®ion); + } + + bool SEqualReduce(const BufferRegionNode* other, SEqualReducer equal) const { + return equal(buffer, other->buffer) && equal(region, other->region); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(buffer); + hash_reduce(region); + } + + static constexpr const char* _type_key = "tir.BufferRegion"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(BufferRegionNode, Object); +}; + +/*! + * \brief Managed reference to BufferRegionNode. + * \sa BufferRegionNode + */ +class BufferRegion : public ObjectRef { + public: + TVM_DLL explicit BufferRegion(Buffer buffer, Array region); + + /*! + * \brief Create a BufferRegion which is full region of the given buffer.. + * \param buffer The buffer to generate full BufferRegion. + * \return The BufferRegion which covers all region of the given buffer + */ + TVM_DLL static BufferRegion FullRegion(Buffer buffer); + + TVM_DEFINE_OBJECT_REF_METHODS(BufferRegion, ObjectRef, BufferRegionNode); +}; + +/*! + * \brief Match introduces a constraint that the source buffer region can be remapped to the data + * layout specified by the buffer field. The constraint can be checked in later part of lowering (or + * optionally during runtime). + * + * MatchBufferRegion provides a mechanism to represent data layout and compactness constraints in + * low-level hardware primitives in the IR and defer the check after the sequence of + * transformations. + */ +class MatchBufferRegionNode : public Object { + public: + /*! \brief The target buffer. */ + Buffer buffer; + /*! \brief The source buffer region. */ + BufferRegion source; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("buffer", &buffer); + v->Visit("source", &source); + } + + bool SEqualReduce(const MatchBufferRegionNode* other, SEqualReducer equal) const { + return equal(buffer, other->buffer) && equal(source, other->source); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(buffer); + hash_reduce(source); + } + + static constexpr const char* _type_key = "tir.MatchBufferRegion"; + static constexpr const bool _type_has_method_sequal_reduce = true; + static constexpr const bool _type_has_method_shash_reduce = true; + TVM_DECLARE_FINAL_OBJECT_INFO(MatchBufferRegionNode, Object); +}; + +/*! + * \brief Managed reference to MatchBufferRegionNode. + * \sa MatchBufferRegionNode + */ +class MatchBufferRegion : public ObjectRef { + public: + TVM_DLL explicit MatchBufferRegion(Buffer buffer, BufferRegion source); + + TVM_DEFINE_OBJECT_REF_METHODS(MatchBufferRegion, ObjectRef, MatchBufferRegionNode); +}; + +/*! + * \brief A block is a basic schedule unit in TIR. + * \note Block's body is parameterized by iter vars. + * \code + * + * with tir.block([extent0, extent1, ...], name) as [v0, v1, ...]: + * tir.bind(v0, value0) + * tir.bind(v1, value1) + * ... + * tir.reads([buffer0[start:end, ...], ...]) + * tir.writes([buffer1[start:end, ...], ...]) + * tir.where(predicate) + * buffer2 = tir.alloc_buffer(shape, dtype) + * buffer3 = tir.match_buffer(source_buffer[start:end, ...]) + * tir.attr({attr_key: attr_value, ...}) + * with tir.init(): + * // init body + * // body + * + * \endcode + */ +class BlockNode : public StmtNode { + public: + /*! \brief The variables of the block. */ + Array iter_vars; + /*! \brief The read buffer regions of the block. */ + Array reads; + /*! \brief The write buffer regions of the block. */ + Array writes; + /*! \brief The name_hint of the block. */ + String name_hint; + /*! \brief The body of the block. */ + Stmt body; + /*! + * \brief The init statement is executed during the first iteration of reduction loops in a + * reduction block. The optional init field allows us to represent initialization and + * reduction update in a single block and transform them collectively. + * We also provide primitives to decompose the init into a separate block during scheduling. + * Init field is `NullOpt` if there is no reduction iter_vars + */ + Optional init; + /*! \brief The buffer allocated in the block. */ + Array alloc_buffers; + /*! \brief The match buffer regions. */ + Array match_buffers; + /*! \brief The annotation of the block. */ + Map annotations; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("iter_vars", &iter_vars); + v->Visit("reads", &reads); + v->Visit("writes", &writes); + v->Visit("name_hint", &name_hint); + v->Visit("body", &body); + v->Visit("init", &init); + v->Visit("alloc_buffers", &alloc_buffers); + v->Visit("match_buffers", &match_buffers); + v->Visit("annotations", &annotations); + } + + bool SEqualReduce(const BlockNode* other, SEqualReducer equal) const { + // Need first reduce iter_vars, alloc_buffers and match_buffers to define new vars + return equal.DefEqual(iter_vars, other->iter_vars) && + equal(alloc_buffers, other->alloc_buffers) && + equal(match_buffers, other->match_buffers) && equal(reads, other->reads) && + equal(writes, other->writes) && equal(body, other->body) && equal(init, other->init) && + equal(annotations, other->annotations); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce.DefHash(iter_vars); + hash_reduce(alloc_buffers); + hash_reduce(match_buffers); + hash_reduce(reads); + hash_reduce(writes); + hash_reduce(body); + hash_reduce(init); + hash_reduce(annotations); + } + + static constexpr const char* _type_key = "tir.Block"; + TVM_DECLARE_FINAL_OBJECT_INFO(BlockNode, StmtNode); +}; + +/*! + * \brief Managed reference to BlockNode. + * \sa BlockNode + */ +class Block : public Stmt { + public: + TVM_DLL explicit Block(Array iter_vars, Array reads, + Array writes, String name_hint, Stmt body, + Optional init = NullOpt, + Array alloc_buffers = Array(), + Array match_buffers = Array(), + Map annotations = Map(), + Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(Block, Stmt, BlockNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockNode); +}; + +/*! + * \brief A block realization node represents execution of the block at the binding values. + */ +class BlockRealizeNode : public StmtNode { + public: + /*! \brief The corresponding values of the iter vars. */ + Array iter_values; + /*! + * \brief The predicate of the block realization, the block will only be executed when the + * predicate is true. + */ + PrimExpr predicate; + /*! \brief The block to be realized. */ + Block block; + + void VisitAttrs(AttrVisitor* v) { + v->Visit("iter_values", &iter_values); + v->Visit("predicate", &predicate); + v->Visit("block", &block); + } + + bool SEqualReduce(const BlockRealizeNode* other, SEqualReducer equal) const { + return equal(iter_values, other->iter_values) && equal(predicate, other->predicate) && + equal(block, other->block); + } + + void SHashReduce(SHashReducer hash_reduce) const { + hash_reduce(iter_values); + hash_reduce(predicate); + hash_reduce(block); + } + + static constexpr const char* _type_key = "tir.BlockRealize"; + TVM_DECLARE_FINAL_OBJECT_INFO(BlockRealizeNode, StmtNode); +}; + +/*! + * \brief Managed reference to BlockRealizeNode + * \sa BlockRealizeNode + */ +class BlockRealize : public Stmt { + public: + TVM_DLL explicit BlockRealize(Array iter_values, PrimExpr predicate, Block block, + Span span = Span()); + + TVM_DEFINE_OBJECT_REF_METHODS(BlockRealize, Stmt, BlockRealizeNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockRealizeNode); +}; + /*! \brief namespace of possible attribute sin AttrStmt.attr_key */ namespace attr { // The above attr does not pass to ir stage. diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 0f4238deeebd..e53b02d73e1d 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -96,6 +96,8 @@ class StmtFunctor { virtual R VisitStmt_(const PrefetchNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const BlockNode* op, Args... args) STMT_FUNCTOR_DEFAULT; + virtual R VisitStmt_(const BlockRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT; virtual R VisitStmtDefault_(const Object* op, Args...) { LOG(FATAL) << "Do not have a default for " << op->GetTypeKey(); return R(); @@ -119,6 +121,8 @@ class StmtFunctor { IR_STMT_FUNCTOR_DISPATCH(EvaluateNode); IR_STMT_FUNCTOR_DISPATCH(BufferStoreNode); IR_STMT_FUNCTOR_DISPATCH(BufferRealizeNode); + IR_STMT_FUNCTOR_DISPATCH(BlockNode); + IR_STMT_FUNCTOR_DISPATCH(BlockRealizeNode); return vtable; } }; @@ -158,6 +162,8 @@ class TVM_DLL StmtVisitor : protected StmtFunctor { void VisitStmt_(const PrefetchNode* op) override; void VisitStmt_(const SeqStmtNode* op) override; void VisitStmt_(const EvaluateNode* op) override; + void VisitStmt_(const BlockNode* op) override; + void VisitStmt_(const BlockRealizeNode* op) override; }; /*! @@ -249,6 +255,8 @@ class TVM_DLL StmtMutator : protected StmtFunctor { Stmt VisitStmt_(const PrefetchNode* op) override; Stmt VisitStmt_(const SeqStmtNode* op) override; Stmt VisitStmt_(const EvaluateNode* op) override; + Stmt VisitStmt_(const BlockNode* op) override; + Stmt VisitStmt_(const BlockRealizeNode* op) override; /*! * \brief Alternative advance method for SeqStmtNode. * diff --git a/python/tvm/tir/__init__.py b/python/tvm/tir/__init__.py index 324c4daf19ba..ad91eab64b52 100644 --- a/python/tvm/tir/__init__.py +++ b/python/tvm/tir/__init__.py @@ -31,6 +31,7 @@ from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt from .stmt import ProducerRealize, SeqStmt from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list +from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize from .function import PrimFunc diff --git a/python/tvm/tir/stmt.py b/python/tvm/tir/stmt.py index 5882dca5578e..e4f1ac924a83 100644 --- a/python/tvm/tir/stmt.py +++ b/python/tvm/tir/stmt.py @@ -26,11 +26,15 @@ assert isinstance(st, tvm.tir.stmt.Store) assert(st.buffer_var == a) """ +from typing import List, Optional, Mapping from enum import IntEnum import tvm._ffi from tvm.runtime import Object +from tvm.ir import Span, PrimExpr, Range from . import _ffi_api +from .buffer import Buffer +from .expr import IterVar class Stmt(Object): @@ -429,6 +433,164 @@ def __init__(self, buffer, bounds, span=None): self.__init_handle_by_constructor__(_ffi_api.Prefetch, buffer, bounds, span) +@tvm._ffi.register_object("tir.BufferRegion") +class BufferRegion(Object): + """BufferRegion node. + + Parameters + ---------- + buffer : Buffer + The buffer of the buffer region + + region : List[Range] + The region array of the buffer region + """ + + buffer: Buffer + region: List[Range] + + def __init__(self, buffer: Buffer, region: List[Range]): + self.__init_handle_by_constructor__(_ffi_api.BufferRegion, buffer, region) + + +@tvm._ffi.register_object("tir.MatchBufferRegion") +class MatchBufferRegion(Object): + """MatchBufferRegion node. + + Parameters + ---------- + buffer : Buffer + The target buffer + + source : BufferRegion + The region of source buffer + """ + + buffer: Buffer + source: BufferRegion + + def __init__(self, buffer: Buffer, source: BufferRegion): + self.__init_handle_by_constructor__(_ffi_api.MatchBufferRegion, buffer, source) + + +@tvm._ffi.register_object("tir.Block") +class Block(Stmt): + """Block node. + + Parameters + ---------- + iter_vars : List[IterVar] + The block Variable. + + reads : List[BufferRegion] + The read buffer regions of the block. + + writes: List[BufferRegion] + The write buffer regions of the block. + + name_hint: str + the name_hint of the block. + + body: Stmt + The body of the block. + + init: Optional[Stmt] + The init block of the reduction block + + alloc_buffers: Optional[list[Buffer]] + The buffer allocations + + match_buffers: Optional[List[MatchBufferRegion]] + The subregion buffer match + + annotations: Optional[Mapping[str, Object]] + Additional annotation hints. + + span : Optional[Span] + The location of this block in the source code. + """ + + iter_vars: List[IterVar] + reads: List[BufferRegion] + writes: List[BufferRegion] + name_hint: str + body: Stmt + init: Optional[Stmt] + alloc_buffers: Optional[List[Buffer]] + match_buffers: Optional[List[MatchBufferRegion]] + annotations: Optional[Mapping[str, Object]] + span: Optional[Span] + + def __init__( + self, + iter_vars: List[IterVar], + reads: List[BufferRegion], + writes: List[BufferRegion], + name_hint: str, + body: Stmt, + init: Optional[Stmt] = None, + alloc_buffers: Optional[List[Buffer]] = None, + match_buffers: Optional[List[MatchBufferRegion]] = None, + annotations: Optional[Mapping[str, Object]] = None, + span: Optional[Span] = None, + ): + if alloc_buffers is None: + alloc_buffers = [] + if match_buffers is None: + match_buffers = [] + if annotations is None: + annotations = {} + self.__init_handle_by_constructor__( + _ffi_api.Block, + iter_vars, + reads, + writes, + name_hint, + body, + init, + alloc_buffers, + match_buffers, + annotations, + span, + ) + + +@tvm._ffi.register_object("tir.BlockRealize") +class BlockRealize(Stmt): + """BlockRealize node. + + Parameters + ---------- + iter_values : List[PrimExpr] + The binding values of the block var. + + predicate : PrimExpr + The predicate of the block. + + block : Block + The block to realize + + span : Optional[Span] + The location of this block_realize in the source code. + """ + + iter_values: List[PrimExpr] + predicate: PrimExpr + block: Block + span: Optional[Span] + + def __init__( + self, + iter_values: List[PrimExpr], + predicate: PrimExpr, + block: Block, + span: Optional[Span] = None, + ): + self.__init_handle_by_constructor__( + _ffi_api.BlockRealize, iter_values, predicate, block, span + ) + + def stmt_seq(*args): """Make sequence of statements diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 92dc38797544..e54be4347c8e 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -598,6 +598,225 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) p->stream << "}\n"; }); +// BufferRegion +BufferRegion::BufferRegion(Buffer buffer, Array region) { + ObjectPtr node = make_object(); + node->buffer = std::move(buffer); + node->region = std::move(region); + data_ = std::move(node); +} + +BufferRegion BufferRegion::FullRegion(Buffer buffer) { + Array region; + for (PrimExpr extent : buffer->shape) { + region.push_back(Range::FromMinExtent(0, extent)); + } + return BufferRegion(buffer, region); +} + +TVM_REGISTER_GLOBAL("tir.BufferRegion").set_body_typed([](Buffer buffer, Array region) { + return BufferRegion(buffer, region); +}); + +TVM_REGISTER_NODE_TYPE(BufferRegionNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->stream << op->buffer->name; + p->stream << "["; + for (size_t i = 0; i < op->region.size(); ++i) { + const auto& range = op->region[i]; + p->Print(range->min); + if (!is_one(range->extent)) { + p->stream << ":"; + p->Print(range->min + range->extent); + } + if (i != op->region.size() - 1) p->stream << ", "; + } + p->stream << "]"; + }); + +// MatchBufferRegion +MatchBufferRegion::MatchBufferRegion(Buffer buffer, BufferRegion source) { + ObjectPtr node = make_object(); + node->buffer = std::move(buffer); + node->source = std::move(source); + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.MatchBufferRegion").set_body_typed([](Buffer buffer, BufferRegion source) { + return MatchBufferRegion(buffer, source); +}); + +TVM_REGISTER_NODE_TYPE(MatchBufferRegionNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + p->stream << op->buffer->name << " = match_buffer_region("; + p->Print(op->source); + p->stream << ")\n"; + }); + +// Block +Block::Block(Array iter_vars, Array reads, Array writes, + String name_hint, Stmt body, Optional init, Array alloc_buffers, + Array match_buffers, Map annotations, + Span span) { + ObjectPtr node = make_object(); + node->iter_vars = std::move(iter_vars); + node->reads = std::move(reads); + node->writes = std::move(writes); + node->name_hint = std::move(name_hint); + node->body = std::move(body); + node->init = std::move(init); + node->alloc_buffers = std::move(alloc_buffers); + node->match_buffers = std::move(match_buffers); + node->annotations = std::move(annotations); + node->span = std::move(span); + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.Block") + .set_body_typed([](Array iter_vars, Array reads, + Array writes, String name_hint, Stmt body, Optional init, + Array alloc_buffers, Array match_buffers, + Map annotations, Span span) { + return Block(iter_vars, reads, writes, name_hint, body, init, alloc_buffers, match_buffers, + annotations, span); + }); + +TVM_REGISTER_NODE_TYPE(BlockNode); + +void PrintBlockTitle(const BlockNode* op, ReprPrinter* p) { + p->stream << "block " << op->name_hint << "("; + for (size_t i = 0; i < op->iter_vars.size(); i++) { + p->Print(op->iter_vars[i]); + if (i < op->iter_vars.size() - 1) p->stream << ", "; + } + p->stream << ")"; +} + +void PrintBlockSignature(const BlockNode* op, ReprPrinter* p) { + // print read/write regions + p->PrintIndent(); + p->stream << "reads("; + p->Print(op->reads); + p->stream << ")\n"; + p->PrintIndent(); + p->stream << "writes("; + p->Print(op->writes); + p->stream << ")\n"; + // Print alloc_buffers + for (const auto& alloc_buf : op->alloc_buffers) { + p->PrintIndent(); + p->stream << alloc_buf->name << " = alloc_buffer(" << alloc_buf->dtype << "["; + for (size_t i = 0; i < alloc_buf->shape.size(); ++i) { + if (i > 0) p->stream << ", "; + p->Print(alloc_buf->shape[i]); + } + p->stream << "])\n"; + } + // Print match_buffer_regions + for (const auto& match_buf : op->match_buffers) { + p->Print(match_buf); + } + if (!op->annotations.empty()) { + p->PrintIndent(); + p->stream << "annotations(" << op->annotations << ")\n"; + } +} + +void PrintBlockBody(const BlockNode* op, ReprPrinter* p) { + // Print init + if (op->init.defined()) { + p->PrintIndent(); + p->stream << "with init() {\n"; + p->indent += 2; + p->Print(op->init.value()); + p->indent -= 2; + p->PrintIndent(); + p->stream << "}\n"; + } + // Print body + p->Print(op->body); +} + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + p->PrintIndent(); + PrintBlockTitle(op, p); + p->stream << "{\n"; + p->indent += 2; + + // Print block elements (e.g. reads/writes, etc) + PrintBlockSignature(op, p); + // Print block init and body + PrintBlockBody(op, p); + + p->indent -= 2; + p->PrintIndent(); + p->stream << "}\n"; + }); + +// BlockRealize +BlockRealize::BlockRealize(Array values, PrimExpr predicate, Block block, Span span) { + CHECK_EQ(block->iter_vars.size(), values.size()) + << "ValueError: BlockRealize needs to have the same number of iter_vars and binding values"; + CHECK(predicate.dtype().is_bool()) << "TypeError: Expect Block.predicate to be a bool expression"; + ObjectPtr node = make_object(); + node->iter_values = std::move(values); + node->predicate = std::move(predicate); + node->block = std::move(block); + node->span = std::move(span); + data_ = std::move(node); +} + +TVM_REGISTER_GLOBAL("tir.BlockRealize") + .set_body_typed([](Array iter_values, PrimExpr predicate, Block block, Span span) { + return BlockRealize(iter_values, predicate, block, span); + }); + +TVM_REGISTER_NODE_TYPE(BlockRealizeNode); + +TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable) + .set_dispatch([](const ObjectRef& node, ReprPrinter* p) { + auto* op = static_cast(node.get()); + auto* block_op = op->block.get(); + p->PrintIndent(); + PrintBlockTitle(block_op, p); + p->stream << "{\n"; + p->indent += 2; + + // Print binding iter_values + for (size_t i = 0; i < block_op->iter_vars.size(); ++i) { + p->PrintIndent(); + p->stream << "bind("; + p->Print(block_op->iter_vars[i]->var); + p->stream << ", "; + p->Print(op->iter_values[i]); + p->stream << ")\n"; + } + // Print predicate + if (!is_one(op->predicate)) { + p->PrintIndent(); + p->stream << "where("; + p->Print(op->predicate); + p->stream << ")\n"; + } + // Print block elements (e.g. reads/writes, etc) + PrintBlockSignature(block_op, p); + // Print block init and body + PrintBlockBody(block_op, p); + + p->indent -= 2; + p->PrintIndent(); + p->stream << "}\n"; + }); + PrimExpr TypeAnnotation(DataType dtype, Span span) { static auto op = Op::Get("tir.type_annotation"); return tir::Call(dtype, op, {}, span); diff --git a/src/tir/ir/stmt_functor.cc b/src/tir/ir/stmt_functor.cc index e4cc1b7e4275..f05dc7116494 100644 --- a/src/tir/ir/stmt_functor.cc +++ b/src/tir/ir/stmt_functor.cc @@ -112,6 +112,35 @@ void StmtVisitor::VisitStmt_(const SeqStmtNode* op) { void StmtVisitor::VisitStmt_(const EvaluateNode* op) { this->VisitExpr(op->value); } +void StmtVisitor::VisitStmt_(const BlockNode* op) { + auto fvisit_buffer_region = [this](const BufferRegion& s) { + for (const auto& range : s->region) { + this->VisitExpr(range->min); + this->VisitExpr(range->extent); + } + }; + VisitArray(op->iter_vars, [this](const IterVar& iter_var) { + this->VisitExpr(iter_var->dom->min); + this->VisitExpr(iter_var->dom->extent); + }); + VisitArray(op->reads, fvisit_buffer_region); + VisitArray(op->writes, fvisit_buffer_region); + VisitArray(op->match_buffers, + [fvisit_buffer_region](const MatchBufferRegion& match_buffer_region) { + fvisit_buffer_region(match_buffer_region->source); + }); + if (op->init.defined()) { + this->VisitStmt(op->init.value()); + } + this->VisitStmt(op->body); +} + +void StmtVisitor::VisitStmt_(const BlockRealizeNode* op) { + VisitArray(op->iter_values, [this](const PrimExpr& e) { this->VisitExpr(e); }); + this->VisitExpr(op->predicate); + this->VisitStmt(op->block); +} + class StmtMutator::Internal { public: /*! @@ -150,6 +179,20 @@ class StmtMutator::Internal { } } + static Array Mutate(StmtMutator* self, const Array& arr) { + auto fmutate = [self](const IterVar& iter_var) { + PrimExpr min = self->VisitExpr(iter_var->dom->min); + PrimExpr extent = self->VisitExpr(iter_var->dom->extent); + if (min.same_as(iter_var->dom->min) && extent.same_as(iter_var->dom->extent)) { + return iter_var; + } else { + return IterVar(Range(min, extent), iter_var->var, iter_var->iter_type, + iter_var->thread_tag); + } + }; + return MutateArray(self, arr, fmutate); + } + static Array Mutate(StmtMutator* self, const Array& arr) { auto fmutate = [self](const PrimExpr& e) { return self->VisitExpr(e); }; return MutateArray(self, arr, fmutate); @@ -172,6 +215,31 @@ class StmtMutator::Internal { }; return MutateArray(self, arr, fmutate); } + + static Array Mutate(StmtMutator* self, const Array& arr) { + auto fmutate = [self](const BufferRegion& buffer_region) { + Array region = Mutate(self, buffer_region->region); + if (region.same_as(buffer_region->region)) { + return buffer_region; + } else { + return BufferRegion(buffer_region->buffer, region); + } + }; + return MutateArray(self, arr, fmutate); + } + + static Array Mutate(StmtMutator* self, const Array& arr) { + auto fmutate = [self](const MatchBufferRegion& match_buffer_region) { + Array region = Mutate(self, match_buffer_region->source->region); + if (region.same_as(match_buffer_region->source->region)) { + return match_buffer_region; + } else { + return MatchBufferRegion(match_buffer_region->buffer, + BufferRegion(match_buffer_region->source->buffer, region)); + } + }; + return MutateArray(self, arr, fmutate); + } }; Stmt StmtMutator::VisitStmt_(const AttrStmtNode* op) { @@ -415,6 +483,47 @@ Stmt StmtMutator::VisitStmt_(const EvaluateNode* op) { } } +Stmt StmtMutator::VisitStmt_(const BlockNode* op) { + Array iter_vars = Internal::Mutate(this, op->iter_vars); + Array reads = Internal::Mutate(this, op->reads); + Array writes = Internal::Mutate(this, op->writes); + Array match_buffers = Internal::Mutate(this, op->match_buffers); + Optional init = NullOpt; + if (op->init.defined()) { + init = VisitStmt(op->init.value()); + } + Stmt body = VisitStmt(op->body); + if (iter_vars.same_as(op->iter_vars) && reads.same_as(op->reads) && writes.same_as(op->writes) && + body.same_as(op->body) && init.same_as(op->init) && + match_buffers.same_as(op->match_buffers)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->iter_vars = std::move(iter_vars); + n->reads = std::move(reads); + n->writes = std::move(writes); + n->body = std::move(body); + n->init = std::move(init); + n->match_buffers = std::move(match_buffers); + return Stmt(n); + } +} + +Stmt StmtMutator::VisitStmt_(const BlockRealizeNode* op) { + Array v = Internal::Mutate(this, op->iter_values); + PrimExpr pred = this->VisitExpr(op->predicate); + Stmt block = this->VisitStmt(op->block); + if (v.same_as(op->iter_values) && pred.same_as(op->predicate) && block.same_as(op->block)) { + return GetRef(op); + } else { + auto n = CopyOnWrite(op); + n->iter_values = std::move(v); + n->predicate = std::move(pred); + n->block = Downcast(block); + return Stmt(n); + } +} + // Implementations of IRTransform, PostOrderVisit and Substitute class IRApplyVisit : public StmtExprVisitor { public: diff --git a/tests/cpp/ir_functor_test.cc b/tests/cpp/ir_functor_test.cc index d242b20f1ba7..237dc46b99ca 100644 --- a/tests/cpp/ir_functor_test.cc +++ b/tests/cpp/ir_functor_test.cc @@ -120,6 +120,25 @@ TEST(IRF, StmtVisitor) { }; v(fmaketest()); ICHECK_EQ(v.count, 3); + + { + // tests for block and block_realize + Stmt body = fmaketest(); + DataType dtype = DataType::Float(32); + Var buf_var("b", PointerType(PrimType(dtype))); + Buffer buffer = decl_buffer({16}); + BufferRegion buffer_region(buffer, {Range::FromMinExtent(x + 1, 1)}); + MatchBufferRegion match_buffer_region(decl_buffer({1}), buffer_region); + + // construct block and block_realize + Block block = + Block({}, {buffer_region}, {buffer_region}, "block", body, body, {}, {match_buffer_region}); + Stmt block_realize = BlockRealize({}, const_true(), block); + + v.count = 0; + v(block_realize); + ICHECK_EQ(v.count, 9); + } } TEST(IRF, StmtMutator) { @@ -229,6 +248,28 @@ TEST(IRF, StmtMutator) { // the seq get flattened ICHECK(body.as()->seq[0].as()->extents.get() != extentptr); } + + { + // tests for block and block_realize + Stmt body = fmakealloc(); + DataType dtype = DataType::Float(32); + Var buf_var("b", PointerType(PrimType(dtype))); + Buffer buffer = decl_buffer({16}); + BufferRegion buffer_region(buffer, {Range::FromMinExtent(x + 1, 1)}); + MatchBufferRegion match_buffer_region(decl_buffer({1}), buffer_region); + // construct block and block_realize + Block block = + Block({}, {buffer_region}, {buffer_region}, "block", body, body, {}, {match_buffer_region}); + Stmt block_realize = BlockRealize({}, const_true(), block); + body = v(std::move(block_realize)); + // the body should be changed + Block new_block = body.as()->block; + ICHECK(new_block->body.as()->extents[1].same_as(x)); + ICHECK(new_block->init.as()->extents[1].same_as(x)); + ICHECK(new_block->reads[0]->region[0]->min.same_as(x)); + ICHECK(new_block->writes[0]->region[0]->min.same_as(x)); + ICHECK(new_block->match_buffers[0]->source->region[0]->min.same_as(x)); + } } int main(int argc, char** argv) { diff --git a/tests/python/unittest/test_tir_nodes.py b/tests/python/unittest/test_tir_nodes.py index bff60f70f53b..6e338d64a61c 100644 --- a/tests/python/unittest/test_tir_nodes.py +++ b/tests/python/unittest/test_tir_nodes.py @@ -364,6 +364,87 @@ def test_intimm_cond(): assert x == 1 +def test_block_blockrealize(): + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + vx = tvm.tir.IterVar((16, 16), "vx", 0) + vx_var = vx.var + vy = tvm.tir.IterVar((16, 16), "vy", 2) + vy_var = vy.var + A = tvm.tir.decl_buffer((16), "float32") + B = tvm.tir.decl_buffer((16, 16), "float32") + alloc_buffer = tvm.tir.decl_buffer((16, 16), "float32") + match_buffer = tvm.tir.decl_buffer((16, 16), "float32") + init_body = tvm.tir.BufferStore(A, 0.0, [vx_var]) + body = tvm.tir.BufferStore( + A, + tvm.tir.BufferLoad(A, [vx_var]) + tvm.tir.BufferLoad(B, [vx_var, vy_var]), + [vx_var], + ) + reads = [ + tvm.tir.BufferRegion( + B, [tvm.ir.Range.from_min_extent(vx_var, 1), tvm.ir.Range.from_min_extent(vy_var, 1)] + ) + ] + writes = [tvm.tir.BufferRegion(A, [tvm.ir.Range.from_min_extent(vx_var, 1)])] + match_buffer_region = tvm.tir.MatchBufferRegion( + match_buffer, tvm.tir.BufferRegion(B, [tvm.ir.Range(0, 16), tvm.ir.Range(0, 16)]) + ) + + block = tvm.tir.Block( + [vx, vy], + reads, + writes, + "block", + body, + init=init_body, + alloc_buffers=[alloc_buffer], + match_buffers=[match_buffer_region], + annotations={"attr_key": "attr_value"}, + ) + + # Checking Block + assert isinstance(block, tvm.tir.Block) + # Checking iter_vars + assert block.iter_vars[0] == vx + assert block.iter_vars[1] == vy + # Checking reads/writes region + assert isinstance(block.reads[0], tvm.tir.BufferRegion) + assert block.reads[0].buffer == B + assert block.reads[0].region[0].min == vx_var + assert block.reads[0].region[1].min == vy_var + assert isinstance(block.writes[0], tvm.tir.BufferRegion) + assert block.writes[0].buffer == A + assert block.writes[0].region[0].min == vx_var + assert block.writes[0].region[0].extent == 1 + # Checking name_hint + assert block.name_hint == "block" + # Checking body + assert block.body == body + # Checking init + assert block.init == init_body + # Checking alloc_buffers + assert block.alloc_buffers[0] == alloc_buffer + # Checking match_buffers + assert block.match_buffers[0].buffer == match_buffer + assert isinstance(block.match_buffers[0].source, tvm.tir.BufferRegion) + assert block.match_buffers[0].source.buffer == B + assert block.match_buffers[0].source.region[0].min == 0 + assert block.match_buffers[0].source.region[0].extent == 16 + + # Checking BlockRealize + block_realize = tvm.tir.BlockRealize([x, y], tvm.tir.const(True, "bool"), block) + assert isinstance(block_realize, tvm.tir.BlockRealize) + assert block_realize.iter_values[0] == x + assert block_realize.iter_values[1] == y + assert block_realize.predicate == tvm.tir.const(True, "bool") + assert block_realize.block == block + + # make sure we can print + str(block) + str(block_realize) + + if __name__ == "__main__": test_intimm_cond() test_buffer_load_store() @@ -389,3 +470,4 @@ def test_intimm_cond(): test_isnan() test_equality() test_equality_string_imm() + test_block_blockrealize()