Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions include/tvm/script/ir_builder/tir/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,57 @@ class BlockFrame : public TIRFrame {
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BlockFrame, TIRFrame, BlockFrameNode);
};

/*!
* \brief A frame that represents the for loop.
*
* \sa ForFrame
*/
class ForFrameNode : public TIRFrameNode {
public:
/*!
* \brief Functions that generate loop nests.
* \param loop_vars The loop variables, from outer to inner
* \param loop_extents The loop extents that correspond to loop variables
* \param loop_body The loop body
* \return A stmt, the loop nest
*/
using FMakeForLoop = runtime::TypedPackedFunc<tvm::tir::Stmt(
Array<tvm::tir::Var> loop_vars, Array<Range> loop_extents, tvm::tir::Stmt loop_body)>;
/*! \brief The loop variable. */
Array<tvm::tir::Var> vars;
/*! \brief The domains of iteration. */
Array<Range> doms;
/*! \brief The for loop generating function. */
FMakeForLoop f_make_for_loop;

void VisitAttrs(tvm::AttrVisitor* v) {
TIRFrameNode::VisitAttrs(v);
v->Visit("vars", &vars);
v->Visit("doms", &doms);
// `f_make_for_loop` is not visited.
}

static constexpr const char* _type_key = "script.ir_builder.tir.ForFrame";
TVM_DECLARE_FINAL_OBJECT_INFO(ForFrameNode, TIRFrameNode);

public:
/*!
* \brief The method called when exiting RAII scope.
* \sa tvm::support::With
*/
void ExitWithScope() final;
};

/*!
* \brief Managed reference to ForFrameNode.
*
* \sa ForFrameNode
*/
class ForFrame : public TIRFrame {
public:
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(ForFrame, TIRFrame, ForFrameNode);
};

/*!
* \brief A frame that represents the assert statement. Proceeds if the condition is true,
* otherwise aborts with the message.
Expand Down
53 changes: 53 additions & 0 deletions include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,59 @@ void PreflattenedBuffer(Buffer postflattened_buffer, Array<PrimExpr> shape,
*/
BlockFrame Block(String name, bool no_realize = false);

/*!
* \brief The serial For statement.
* \param start The minimum value of iteration.
* \param stop The maximum value of iteration.
* \param annotations The optional annotations of the For statement.
* \return The ForFrame.
*/
ForFrame Serial(PrimExpr start, PrimExpr stop,
Optional<Map<String, ObjectRef>> annotations = NullOpt);
/*!
* \brief The parallel For statement.
* \param start The minimum value of iteration.
* \param stop The maximum value of iteration.
* \param annotations The optional annotations of the For statement.
* \return The ForFrame.
*/
ForFrame Parallel(PrimExpr start, PrimExpr stop,
Optional<Map<String, ObjectRef>> annotations = NullOpt);
/*!
* \brief The vectorized For statement.
* \param start The minimum value of iteration.
* \param stop The maximum value of iteration.
* \param annotations The optional annotations of the For statement.
* \return The ForFrame.
*/
ForFrame Vectorized(PrimExpr start, PrimExpr stop,
Optional<Map<String, ObjectRef>> annotations = NullOpt);
/*!
* \brief The unrolled For statement.
* \param start The minimum value of iteration.
* \param stop The maximum value of iteration.
* \param annotations The optional annotations of the For statement.
* \return The ForFrame.
*/
ForFrame Unroll(PrimExpr start, PrimExpr stop,
Optional<Map<String, ObjectRef>> annotations = NullOpt);
/*!
* \brief The thread-binding For statement.
* \param start The minimum value of iteration.
* \param stop The maximum value of iteration.
* \param thread The thread for loop variable to bind.
* \param annotations The optional annotations of the For statement.
* \return The ForFrame.
*/
ForFrame ThreadBinding(PrimExpr start, PrimExpr stop, String thread,
Optional<Map<String, ObjectRef>> annotations = NullOpt);
/*!
* \brief The grid For statement.
* \param extents The extents of the iteration.
* \return The ForFrame.
*/
ForFrame Grid(Array<PrimExpr> extents);

/*!
* \brief Evaluate the input expression.
* \param value The input expression to evaluate.
Expand Down
9 changes: 9 additions & 0 deletions python/tvm/script/ir_builder/tir/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@
# specific language governing permissions and limitations
# under the License.
"""IRBuilder for TIR"""
from typing import List, Union

from tvm._ffi import register_object as _register_object
from tvm.tir import Var

from ..base import IRBuilderFrame

Expand All @@ -34,3 +36,10 @@ class PrimFuncFrame(TIRFrame):
@_register_object("script.ir_builder.tir.BlockFrame")
class BlockFrame(TIRFrame):
...


@_register_object("script.ir_builder.tir.ForFrame")
class ForFrame(TIRFrame):
def __enter__(self) -> Union[Var, List[Var]]:
super().__enter__()
return self.vars if len(self.vars) > 1 else self.vars[0]
172 changes: 172 additions & 0 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,6 +344,172 @@ def block(name: str = "", no_realize: bool = False) -> frame.BlockFrame:
return _ffi_api.Block(name, no_realize) # pylint: disable=no-member # type: ignore


def serial(
start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
) -> frame.ForFrame:
"""The serial For statement.

Parameters
----------
start : PrimExpr
The minimum value of iteration.

stop : PrimExpr
The maximum value of iteration.

annotations : Dict[str, Any]
The optional annotations of the For statement.

Returns
-------
res : frame.ForFrame
The ForFrame.
"""
if stop is None:
stop = start
start = 0
return _ffi_api.Serial(start, stop, annotations) # pylint: disable=no-member # type: ignore


def parallel(
start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
) -> frame.ForFrame:
"""The parallel For statement.

Parameters
----------
start : PrimExpr
The minimum value of iteration.

stop : PrimExpr
The maximum value of iteration.

annotations : Dict[str, Any]
The optional annotations of the For statement.

Returns
-------
res : frame.ForFrame
The ForFrame.
"""
if stop is None:
stop = start
start = 0
return _ffi_api.Parallel(start, stop, annotations) # pylint: disable=no-member # type: ignore


def vectorized(
start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
) -> frame.ForFrame:
"""The vectorized For statement.

Parameters
----------
start : PrimExpr
The minimum value of iteration.

stop : PrimExpr
The maximum value of iteration.

annotations : Dict[str, Any]
The optional annotations of the For statement.

Returns
-------
res : frame.ForFrame
The ForFrame.
"""
if stop is None:
stop = start
start = 0
return _ffi_api.Vectorized(start, stop, annotations) # pylint: disable=no-member # type: ignore


def unroll(
start: PrimExpr, stop: PrimExpr = None, *, annotations: Dict[str, Any] = None
) -> frame.ForFrame:
"""The unrolled For statement.

Parameters
----------
start : PrimExpr
The minimum value of iteration.

stop : PrimExpr
The maximum value of iteration.

annotations : Dict[str, Any]
The optional annotations of the For statement.

Returns
-------
res : frame.ForFrame
The ForFrame.
"""
if stop is None:
stop = start
start = 0
return _ffi_api.Unroll(start, stop, annotations) # pylint: disable=no-member # type: ignore


def thread_binding(
start: PrimExpr,
stop: PrimExpr = None,
thread: str = None,
*,
annotations: Dict[str, Any] = None,
) -> frame.ForFrame:
"""The thread-binding For statement.

Parameters
----------
start : PrimExpr
The minimum value of iteration.

stop : PrimExpr
The maximum value of iteration.

thread : str
The thread for loop variable to bind.

annotations : Dict[str, Any]
The optional annotations of the For statement.

Returns
-------
res : frame.ForFrame
The ForFrame.
"""
if thread is None:
if not isinstance(stop, str):
raise ValueError("Thread cannot be None for thread_binding")
thread = stop
stop = start
start = 0
elif stop is None:
stop = start
start = 0
return _ffi_api.ThreadBinding( # pylint: disable=no-member # type: ignore
start, stop, thread, annotations
)


def grid(*extents: PrimExpr) -> frame.ForFrame:
"""The grid For statement.

Parameters
----------
extents : PrimExpr
The extents of the iteration.

Returns
-------
res : frame.ForFrame
The ForFrame.
"""
return _ffi_api.Grid(extents) # pylint: disable=no-member # type: ignore


def evaluate(value: PrimExpr) -> None:
"""Evaluate the input expression.

Expand Down Expand Up @@ -677,6 +843,12 @@ def var(dtype, name="") -> Var:
"match_buffer",
"preflattened_buffer",
"block",
"serial",
"parallel",
"vectorized",
"unroll",
"thread_binding",
"grid",
"evaluate",
"int8",
"int16",
Expand Down
6 changes: 6 additions & 0 deletions src/script/ir_builder/tir/frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,15 @@ void BlockFrameNode::ExitWithScope() {
}
}

void ForFrameNode::ExitWithScope() {
TIRFrameNode::ExitWithScope();
AddToParent(this->f_make_for_loop(vars, doms, AsStmt(stmts)));
}

TVM_REGISTER_NODE_TYPE(TIRFrameNode);
TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode);
TVM_REGISTER_NODE_TYPE(BlockFrameNode);
TVM_REGISTER_NODE_TYPE(ForFrameNode);

} // namespace tir
} // namespace ir_builder
Expand Down
Loading