39 changes: 38 additions & 1 deletion include/tvm/tir/schedule/schedule.h
@@ -110,7 +110,7 @@ class ScheduleNode : public runtime::Object {
* guaranteeing that
* 1) SRef tree is completely reconstructed;
* 2) The IRModule being scheduled is not modified;
* 3) All the random variables are valid in the copy, pointing to the correpsonding sref
* 3) All the random variables are valid in the copy, pointing to the corresponding sref
* reconstructed
*/
virtual Schedule Copy() const = 0;
@@ -220,6 +220,43 @@ class ScheduleNode : public runtime::Object {
*/
virtual Array<LoopRV> Split(const LoopRV& loop_rv, const Array<Optional<ExprRV>>& factors) = 0;
/******** Schedule: Manipulate ForKind ********/
/*!
* \brief Parallelize the input loop. It requires:
* 1) The scope block that the loop is in should have stage-pipeline property
* 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine
* bindings
* 3) For each block under the loop, the loop can only be contained in data-parallel block iters'
* bindings
* \param loop_rv The loop to be parallelized
*/
virtual void Parallel(const LoopRV& loop_rv) = 0;
/*!
* \brief Vectorize the input loop. It requires:
* 1) The scope block that the loop is in should have stage-pipeline property
* 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine
* bindings
* 3) For each block under the loop, the loop can only be contained in data-parallel block iters'
* bindings
* \param loop_rv The loop to be vectorized
*/
virtual void Vectorize(const LoopRV& loop_rv) = 0;
/*!
* \brief Bind the input loop to the given thread axis. It requires:
* 1) The scope block that the loop is in should have stage-pipeline property
* 2) All the blocks under the loop are complete blocks or reduction blocks, and have affine
* bindings
* 3) For each block under the loop, if the thread axis starts with "threadIdx", the loop can only
* be contained in data-parallel block iters' and reduction block iters' bindings. Otherwise the
* loop can only be contained in data-parallel block iters' bindings
* \param loop_rv The loop to be bound to the thread axis
* \param thread_axis The thread axis to be bound to the loop
*/
virtual void Bind(const LoopRV& loop_rv, const String& thread_axis) = 0;
/*!
* \brief Unroll the input loop. It has no prerequisites
* \param loop_rv The loop to be unrolled
*/
virtual void Unroll(const LoopRV& loop_rv) = 0;
/******** Schedule: Insert cache stages ********/
/******** Schedule: Compute location ********/
/*!
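The four ForKind primitives declared above only re-annotate an existing loop (serial to parallel, vectorized, unrolled, or thread-bound); they do not reorder or split it. Below is a minimal sketch of chaining them through the Python API, reusing the 128x128 `before_parallel` workload from the docstrings further down (the function name is borrowed purely for illustration):

.. code-block:: python

    import tvm
    from tvm import tir

    # before_parallel is the @tvm.script.tir function shown in the docstrings below
    sch = tir.Schedule(before_parallel)
    i, j = sch.get_loops(sch.get_block("B"))
    sch.parallel(i)    # outer loop runs across CPU threads
    sch.vectorize(j)   # inner loop becomes a vector loop
    # For a GPU schedule one would instead use sch.bind(i, "blockIdx.x") and
    # sch.bind(j, "threadIdx.x"); sch.unroll(...) fully unrolls a loop.
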
12 changes: 12 additions & 0 deletions include/tvm/tir/transform.h
@@ -437,6 +437,18 @@ TVM_DLL Pass LowerMatchBuffer();
*/
TVM_DLL Pass FlattenBuffer();

/*!
* \brief Unify all the thread bindings for "blockIdx.x/y/z", "threadIdx.x/y/z", and
* "vthread.x/y/z". Before the unification, two vars that are bound to the same thread axis
* (e.g., "threadIdx.x") use different IterVars and variables in their AttrStmts. After the
* unification, they share a single consolidated IterVar and variable.
* \return The pass.
* \note `vthread` is a legacy behavior that will be deprecated, though thread bindings of
* `vthread` are also unified by this pass. Please use `vthread.x`, `vthread.y` and
* `vthread.z` instead.
*/
TVM_DLL Pass UnifyThreadBinding();

/*!
* A pass to merge multiple TIR-level dynamic shared memory allocations into one
*/
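Since UnifyThreadBinding is appended to the default lowering pipeline in src/driver/driver_api.cc below, most users never invoke it directly. A minimal sketch of running it stand-alone, assuming `mod` is an IRModule whose PrimFuncs already contain thread-binding loops (e.g. produced by `sch.bind` above):

.. code-block:: python

    import tvm
    from tvm import tir

    # Apply the pass directly to an IRModule; `mod` is assumed to already exist.
    mod = tir.transform.UnifyThreadBinding()(mod)
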
240 changes: 234 additions & 6 deletions python/tvm/tir/schedule/schedule.py
@@ -170,7 +170,7 @@ def copy(self) -> "Schedule":
* guaranteeing that
* 1) SRef tree is completely reconstructed;
* 2) The IRModule being scheduled is untouched;
* 3) All the random variables are valid in the copy, pointing to the correpsonding sref
* 3) All the random variables are valid in the copy, pointing to the corresponding sref
* reconstructed

Returns
@@ -226,7 +226,7 @@ def get(
Returns
-------
result : Optional[Union[int, Block, For]]
The correpsonding result
The corresponding result
"""
if isinstance(rand_var_or_sref, StmtSRef):
return rand_var_or_sref.stmt
@@ -236,7 +236,7 @@ def get(
return result

def get_sref(self, rand_var_or_stmt: Union[BlockRV, LoopRV, Block, For]) -> Optional[StmtSRef]:
"""Returns the correpsonding sref to the given
"""Returns the corresponding sref to the given
1) LoopRV
2) BlockRV
3) Block
@@ -250,7 +250,7 @@ def get_sref(self, rand_var_or_stmt: Union[BlockRV, LoopRV, Block, For]) -> Optional[StmtSRef]:
Returns
-------
result : Optional[StmtSRef]
The correpsonding result
The corresponding result
"""
return _ffi_api.ScheduleGetSRef( # type: ignore # pylint: disable=no-member
self, rand_var_or_stmt
@@ -413,7 +413,7 @@ def before_split(a: ty.handle, b: ty.handle) -> None:
with tir.block([128, 128], "B") as [vi, vj]:
B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do fuse:
Create the schedule and do split:

.. code-block:: python

@@ -444,6 +444,234 @@ def after_split(a: ty.handle, b: ty.handle) -> None:

########## Schedule: Manipulate ForKind ##########

def parallel(self, loop: LoopRV) -> None:
"""Parallelize the input loop. It requires:
1) The scope block that the loop is in should have stage-pipeline property
2) All the blocks under the loop are complete blocks or reduction blocks, and have affine
bindings
3) For each block under the loop, the loop can only be contained in data-parallel block
iters' bindings

Parameters
----------
loop : LoopRV
The loop to be parallelized

Examples
--------

Before parallel, in TensorIR, the IR is:

.. code-block:: python

@tvm.script.tir
def before_parallel(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
for i, j in tir.grid(128, 128):
with tir.block([128, 128], "B") as [vi, vj]:
tir.bind(vi, i)
tir.bind(vj, j)
B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do parallel:

.. code-block:: python

sch = tir.Schedule(before_parallel)
i, j = sch.get_loops(sch.get_block("B"))
sch.parallel(i)

After applying parallel, the IR becomes:

.. code-block:: python

@tvm.script.tir
def after_parallel(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
for i in tir.parallel(0, 128):
for j in tir.serial(0, 128):
with tir.block([128, 128], "B") as [vi, vj]:
tir.bind(vi, i)
tir.bind(vj, j)
B[vi, vj] = A[vi, vj] * 2.0

"""
_ffi_api.ScheduleParallel(self, loop) # type: ignore # pylint: disable=no-member

def vectorize(self, loop: LoopRV) -> None:
"""Vectorize the input loop. It requires:
1) The scope block that the loop is in should have stage-pipeline property
2) All the blocks under the loop are complete blocks or reduction blocks, and have affine
bindings
3) For each block under the loop, the loop can only be contained in data-parallel block
iters' bindings

Parameters
----------
loop : LoopRV
The loop to be vectorized

Examples
--------

Before vectorize, in TensorIR, the IR is:

.. code-block:: python

@tvm.script.tir
def before_vectorize(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
for i, j in tir.grid(128, 128):
with tir.block([128, 128], "B") as [vi, vj]:
tir.bind(vi, i)
tir.bind(vj, j)
B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do vectorize:

.. code-block:: python

sch = tir.Schedule(before_vectorize)
i, j = sch.get_loops(sch.get_block("B"))
sch.vectorize(j)

After applying vectorize, the IR becomes:

.. code-block:: python

@tvm.script.tir
def after_vectorize(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
for i in tir.serial(0, 128):
for j in tir.vectorized(0, 128):
with tir.block([128, 128], "B") as [vi, vj]:
tir.bind(vi, i)
tir.bind(vj, j)
B[vi, vj] = A[vi, vj] * 2.0

"""
_ffi_api.ScheduleVectorize(self, loop) # type: ignore # pylint: disable=no-member

def bind(self, loop: LoopRV, thread_axis: str) -> None:
"""Bind the input loop to the given thread axis. It requires:
1) The scope block that the loop is in should have stage-pipeline property
2) All the blocks under the loop are complete blocks or reduction blocks, and have affine
bindings
3) For each block under the loop, if the thread axis starts with "threadIdx", the loop can
only be contained in data-parallel block iters' and reduction block iters' bindings.
Otherwise the loop can only be contained in data-parallel block iters' bindings

Parameters
----------
loop : LoopRV
The loop to be bound to the thread axis
thread_axis : str
The thread axis to be bound to the loop. Possible candidates:
- blockIdx.x/y/z
- threadIdx.x/y/z
- vthread.x/y/z
- vthread (It is a legacy behavior that will be deprecated. Please use `vthread.x/y/z`
instead.)

Examples
--------

Before bind, in TensorIR, the IR is:

.. code-block:: python

@tvm.script.tir
def before_bind(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
for i, j in tir.grid(128, 128):
with tir.block([128, 128], "B") as [vi, vj]:
tir.bind(vi, i)
tir.bind(vj, j)
B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do bind:

.. code-block:: python

sch = tir.Schedule(before_bind)
i, j = sch.get_loops(sch.get_block("B"))
sch.bind(i, "blockIdx.x")
sch.bind(j, "threadIdx.x")

After applying bind, the IR becomes:

.. code-block:: python

@tvm.script.tir
def after_bind(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
for i in tir.thread_binding(0, 128, thread = "blockIdx.x"):
for j in tir.thread_binding(0, 128, thread = "threadIdx.x"):
with tir.block([128, 128], "B") as [vi, vj]:
tir.bind(vi, i)
tir.bind(vj, j)
B[vi, vj] = A[vi, vj] * 2.0

"""
_ffi_api.ScheduleBind(self, loop, thread_axis) # type: ignore # pylint: disable=no-member

def unroll(self, loop: LoopRV) -> None:
"""Unroll the input loop. It requires nothing

Parameters
----------
loop : LoopRV
The loop to be unrolled

Examples
--------

Before unroll, in TensorIR, the IR is:

.. code-block:: python

@tvm.script.tir
def before_unroll(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
for i, j in tir.grid(128, 128):
with tir.block([128, 128], "B") as [vi, vj]:
tir.bind(vi, i)
tir.bind(vj, j)
B[vi, vj] = A[vi, vj] * 2.0

Create the schedule and do unroll:

.. code-block:: python

sch = tir.Schedule(before_unroll)
i, j = sch.get_loops(sch.get_block("B"))
sch.unroll(i)

After applying unroll, the IR becomes:

.. code-block:: python

@tvm.script.tir
def after_unroll(a: ty.handle, b: ty.handle) -> None:
A = tir.match_buffer(a, (128, 128))
B = tir.match_buffer(b, (128, 128))
for i in tir.unroll(0, 128):
for j in tir.serial(0, 128):
with tir.block([128, 128], "B") as [vi, vj]:
tir.bind(vi, i)
tir.bind(vj, j)
B[vi, vj] = A[vi, vj] * 2.0

"""
_ffi_api.ScheduleUnroll(self, loop) # type: ignore # pylint: disable=no-member

########## Schedule: Insert cache stages ##########

########## Schedule: Compute location ##########
@@ -581,7 +809,7 @@ def rfactor(self, loop: LoopRV, factor_axis: int) -> LoopRV:
RFactor is a schedule primitive that implements the transformation described above:
Given a block that writes to buffer `B`, it factorizes a loop of extent `n`.

For example, the pesudocode below accumulates `B[i] = sum(A[i, : , : ])`:
For example, the pseudocode below accumulates `B[i] = sum(A[i, : , : ])`:

.. code-block:: python

22 changes: 22 additions & 0 deletions python/tvm/tir/transform/transform.py
@@ -668,6 +668,28 @@ def FlattenBuffer():
return _ffi_api.FlattenBuffer() # type: ignore


def UnifyThreadBinding():
"""Unify all the thread bindings for "blockIdx.x/y/z",
"threadIdx.x/y/z", and "vthread.x/y/z". Before the unification,
two vars that are bound to a thread axis (e.g., "threadIdx.x")
use different IterVars and variables in their AttrStmts. After
the unification, we use a consolidated IterVar and a variable
for them.

Returns
-------
fpass : tvm.transform.Pass
The result pass

Note
----
`vthread` is a legacy behavior that will be deprecated, though
thread bindings of `vthread` are also unified by this pass.
Please use `vthread.x`, `vthread.y` and `vthread.z` instead.
"""
return _ffi_api.UnifyThreadBinding() # type: ignore


def MergeDynamicSharedMemoryAllocations():
"""This pass merges multiple TIR-level dynamic shared memory allocations
into one allocation.
1 change: 1 addition & 0 deletions src/driver/driver_api.cc
@@ -222,6 +222,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::LowerMatchBuffer());
pass_list.push_back(tir::transform::FlattenBuffer());
pass_list.push_back(tir::transform::UnifyThreadBinding());
pass_list.push_back(tir::transform::BF16Legalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify());