Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 247 additions & 1 deletion include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -862,7 +862,7 @@ class For : public Stmt {
};

/*!
* \brief A prefetch hint for abuffer
* \brief A prefetch hint for a buffer
*/
class PrefetchNode : public StmtNode {
public:
Expand Down Expand Up @@ -905,6 +905,252 @@ class Prefetch : public Stmt {
TVM_DEFINE_NOTNULLABLE_OBJECT_REF_METHODS(Prefetch, Stmt, PrefetchNode);
};

/*!
* \brief Representing the region of multi-dimensional buffer access.
*/
class BufferRegionNode : public Object {
public:
/*! \brief The buffer of the buffer region. */
Buffer buffer;
/*! \brief The region array of the buffer region. */
Array<Range> region;

void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer", &buffer);
v->Visit("region", &region);
}

bool SEqualReduce(const BufferRegionNode* other, SEqualReducer equal) const {
return equal(buffer, other->buffer) && equal(region, other->region);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(buffer);
hash_reduce(region);
}

static constexpr const char* _type_key = "tir.BufferRegion";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(BufferRegionNode, Object);
};

/*!
* \brief Managed reference to BufferRegionNode.
* \sa BufferRegionNode
*/
class BufferRegion : public ObjectRef {
public:
TVM_DLL explicit BufferRegion(Buffer buffer, Array<Range> region);

/*!
* \brief Create a BufferRegion which is full region of the given buffer..
* \param buffer The buffer to generate full BufferRegion.
* \return The BufferRegion which covers all region of the given buffer
*/
TVM_DLL static BufferRegion FullRegion(Buffer buffer);

TVM_DEFINE_OBJECT_REF_METHODS(BufferRegion, ObjectRef, BufferRegionNode);
};

/*!
* \brief Match introduces a constraint that the source buffer region can be remapped to the data
* layout specified by the buffer field. The constraint can be checked in later part of lowering (or
* optionally during runtime).
*
* MatchBufferRegion provides a mechanism to represent data layout and compactness constraints in
* low-level hardware primitives in the IR and defer the check after the sequence of
* transformations.
*/
class MatchBufferRegionNode : public Object {
public:
/*! \brief The target buffer. */
Buffer buffer;
/*! \brief The source buffer region. */
BufferRegion source;

void VisitAttrs(AttrVisitor* v) {
v->Visit("buffer", &buffer);
v->Visit("source", &source);
}

bool SEqualReduce(const MatchBufferRegionNode* other, SEqualReducer equal) const {
return equal(buffer, other->buffer) && equal(source, other->source);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(buffer);
hash_reduce(source);
}

static constexpr const char* _type_key = "tir.MatchBufferRegion";
static constexpr const bool _type_has_method_sequal_reduce = true;
static constexpr const bool _type_has_method_shash_reduce = true;
TVM_DECLARE_FINAL_OBJECT_INFO(MatchBufferRegionNode, Object);
};

/*!
* \brief Managed reference to MatchBufferRegionNode.
* \sa MatchBufferRegionNode
*/
class MatchBufferRegion : public ObjectRef {
public:
TVM_DLL explicit MatchBufferRegion(Buffer buffer, BufferRegion source);

TVM_DEFINE_OBJECT_REF_METHODS(MatchBufferRegion, ObjectRef, MatchBufferRegionNode);
};

/*!
* \brief A block is a basic schedule unit in TIR.
* \note Block's body is parameterized by iter vars.
* \code
*
* with tir.block([extent0, extent1, ...], name) as [v0, v1, ...]:
* tir.bind(v0, value0)
* tir.bind(v1, value1)
* ...
* tir.reads([buffer0[start:end, ...], ...])
* tir.writes([buffer1[start:end, ...], ...])
* tir.where(predicate)
* buffer2 = tir.alloc_buffer(shape, dtype)
* buffer3 = tir.match_buffer(source_buffer[start:end, ...])
* tir.attr({attr_key: attr_value, ...})
* with tir.init():
* // init body
* // body
*
* \endcode
*/
class BlockNode : public StmtNode {
public:
/*! \brief The variables of the block. */
Array<IterVar> iter_vars;
/*! \brief The read buffer regions of the block. */
Array<BufferRegion> reads;
/*! \brief The write buffer regions of the block. */
Array<BufferRegion> writes;
/*! \brief The name_hint of the block. */
String name_hint;
/*! \brief The body of the block. */
Stmt body;
/*!
* \brief The init statement is executed during the first iteration of reduction loops in a
* reduction block. The optional init field allows us to represent initialization and
* reduction update in a single block and transform them collectively.
* We also provide primitives to decompose the init into a separate block during scheduling.
* Init field is `NullOpt` if there is no reduction iter_vars
*/
Optional<Stmt> init;
/*! \brief The buffer allocated in the block. */
Array<Buffer> alloc_buffers;
/*! \brief The match buffer regions. */
Array<MatchBufferRegion> match_buffers;
/*! \brief The annotation of the block. */
Map<String, ObjectRef> annotations;

void VisitAttrs(AttrVisitor* v) {
v->Visit("iter_vars", &iter_vars);
v->Visit("reads", &reads);
v->Visit("writes", &writes);
v->Visit("name_hint", &name_hint);
v->Visit("body", &body);
v->Visit("init", &init);
v->Visit("alloc_buffers", &alloc_buffers);
v->Visit("match_buffers", &match_buffers);
v->Visit("annotations", &annotations);
}

bool SEqualReduce(const BlockNode* other, SEqualReducer equal) const {
// Need first reduce iter_vars, alloc_buffers and match_buffers to define new vars
return equal.DefEqual(iter_vars, other->iter_vars) &&
equal(alloc_buffers, other->alloc_buffers) &&
equal(match_buffers, other->match_buffers) && equal(reads, other->reads) &&
equal(writes, other->writes) && equal(body, other->body) && equal(init, other->init) &&
equal(annotations, other->annotations);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce.DefHash(iter_vars);
hash_reduce(alloc_buffers);
hash_reduce(match_buffers);
hash_reduce(reads);
hash_reduce(writes);
hash_reduce(body);
hash_reduce(init);
hash_reduce(annotations);
}

static constexpr const char* _type_key = "tir.Block";
TVM_DECLARE_FINAL_OBJECT_INFO(BlockNode, StmtNode);
};

/*!
* \brief Managed reference to BlockNode.
* \sa BlockNode
*/
class Block : public Stmt {
public:
TVM_DLL explicit Block(Array<IterVar> iter_vars, Array<BufferRegion> reads,
Array<BufferRegion> writes, String name_hint, Stmt body,
Optional<Stmt> init = NullOpt,
Array<Buffer> alloc_buffers = Array<Buffer>(),
Array<MatchBufferRegion> match_buffers = Array<MatchBufferRegion>(),
Map<String, ObjectRef> annotations = Map<String, ObjectRef>(),
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(Block, Stmt, BlockNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockNode);
};

/*!
* \brief A block realization node represents execution of the block at the binding values.
*/
class BlockRealizeNode : public StmtNode {
public:
/*! \brief The corresponding values of the iter vars. */
Array<PrimExpr> iter_values;
/*!
* \brief The predicate of the block realization, the block will only be executed when the
* predicate is true.
*/
PrimExpr predicate;
/*! \brief The block to be realized. */
Block block;

void VisitAttrs(AttrVisitor* v) {
v->Visit("iter_values", &iter_values);
v->Visit("predicate", &predicate);
v->Visit("block", &block);
}

bool SEqualReduce(const BlockRealizeNode* other, SEqualReducer equal) const {
return equal(iter_values, other->iter_values) && equal(predicate, other->predicate) &&
equal(block, other->block);
}

void SHashReduce(SHashReducer hash_reduce) const {
hash_reduce(iter_values);
hash_reduce(predicate);
hash_reduce(block);
}

static constexpr const char* _type_key = "tir.BlockRealize";
TVM_DECLARE_FINAL_OBJECT_INFO(BlockRealizeNode, StmtNode);
};

/*!
* \brief Managed reference to BlockRealizeNode
* \sa BlockRealizeNode
*/
class BlockRealize : public Stmt {
public:
TVM_DLL explicit BlockRealize(Array<PrimExpr> iter_values, PrimExpr predicate, Block block,
Span span = Span());

TVM_DEFINE_OBJECT_REF_METHODS(BlockRealize, Stmt, BlockRealizeNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BlockRealizeNode);
};

/*! \brief namespace of possible attribute sin AttrStmt.attr_key */
namespace attr {
// The above attr does not pass to ir stage.
Expand Down
8 changes: 8 additions & 0 deletions include/tvm/tir/stmt_functor.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,8 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
virtual R VisitStmt_(const PrefetchNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const SeqStmtNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const EvaluateNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BlockNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmt_(const BlockRealizeNode* op, Args... args) STMT_FUNCTOR_DEFAULT;
virtual R VisitStmtDefault_(const Object* op, Args...) {
LOG(FATAL) << "Do not have a default for " << op->GetTypeKey();
return R();
Expand All @@ -119,6 +121,8 @@ class StmtFunctor<R(const Stmt& n, Args... args)> {
IR_STMT_FUNCTOR_DISPATCH(EvaluateNode);
IR_STMT_FUNCTOR_DISPATCH(BufferStoreNode);
IR_STMT_FUNCTOR_DISPATCH(BufferRealizeNode);
IR_STMT_FUNCTOR_DISPATCH(BlockNode);
IR_STMT_FUNCTOR_DISPATCH(BlockRealizeNode);
return vtable;
}
};
Expand Down Expand Up @@ -158,6 +162,8 @@ class TVM_DLL StmtVisitor : protected StmtFunctor<void(const Stmt&)> {
void VisitStmt_(const PrefetchNode* op) override;
void VisitStmt_(const SeqStmtNode* op) override;
void VisitStmt_(const EvaluateNode* op) override;
void VisitStmt_(const BlockNode* op) override;
void VisitStmt_(const BlockRealizeNode* op) override;
};

/*!
Expand Down Expand Up @@ -249,6 +255,8 @@ class TVM_DLL StmtMutator : protected StmtFunctor<Stmt(const Stmt&)> {
Stmt VisitStmt_(const PrefetchNode* op) override;
Stmt VisitStmt_(const SeqStmtNode* op) override;
Stmt VisitStmt_(const EvaluateNode* op) override;
Stmt VisitStmt_(const BlockNode* op) override;
Stmt VisitStmt_(const BlockRealizeNode* op) override;
/*!
* \brief Alternative advance method for SeqStmtNode.
*
Expand Down
1 change: 1 addition & 0 deletions python/tvm/tir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from .stmt import BufferStore, BufferRealize, Store, ProducerStore, Allocate, AttrStmt
from .stmt import ProducerRealize, SeqStmt
from .stmt import IfThenElse, Evaluate, Prefetch, stmt_seq, stmt_list
from .stmt import BufferRegion, MatchBufferRegion, Block, BlockRealize

from .function import PrimFunc

Expand Down
Loading