Skip to content

Conversation

@liangW-intellif
Copy link
Contributor

@liangW-intellif liangW-intellif commented Oct 11, 2022

Hi, this PR ported the rolling_buffer primitive from TE schedule to TensorIR schedule, refer to [RFC] Introducing a ‘rolling_buffer’ scheduling primitive.
The primitive performs the following steps to transform the target buffer into a 'rolling buffer':

  1. Collect bound overlaps on the target buffer, and select the outermost rollable axis appeared in the block's loop nest as the 'rolling axis'.
  2. Append block predicate to the producer block of the target buffer to avoid recomputation.
  3. Use modulo arithmetic to modify the target buffer's read and load indices to circularize the buffer along the rolling dimension.

Note: The region_cover property of the consumer block of the target buffer will become false.

Example

  • Before
def before_rolling_buffer(
    A: T.Buffer[(12, 12), "int8"], C: T.Buffer[(8, 8), "int8"]
) -> None:
    # body
    # with T.block("root")
    B = T.alloc_buffer([10, 10], dtype="int8")
    for i0, i1 in T.grid(2, 2):
        for ax0, ax1, ax2, ax3 in T.grid(6, 6, 3, 3):
            with T.block("B"):
                ax0_1 = T.axis.spatial(10, i0 * 4 + ax0)
                ax1_1 = T.axis.spatial(10, i1 * 4 + ax1)
                rv0, rv1 = T.axis.remap("RR", [ax2, ax3])
                B[ax0_1, ax1_1] = T.max(B[ax0_1, ax1_1], A[ax0_1 + rv0, ax1_1 + rv1])
        for ax0, ax1, ax2, ax3 in T.grid(4, 4, 3, 3):
            with T.block("C"):
                ax0_1 = T.axis.spatial(8, i0 * 4 + ax0)
                ax1_1 = T.axis.spatial(8, i1 * 4 + ax1)
                rv0, rv1 = T.axis.remap("RR", [ax2, ax3])
                C[ax0_1, ax1_1] = T.max(C[ax0_1, ax1_1], B[ax0_1 + rv0, ax1_1 + rv1])
  • After sch.rolling_buffer(sch.get_block("B"), buffer_index=0)
@T.prim_func
def after_rolling_buffer(
    A: T.Buffer[(12, 12), "int8"],
    C: T.Buffer[(8, 8), "int8"]
) -> None:
    # body
    # with T.block("root")
    B = T.alloc_buffer([6, 10], dtype="int8")
    for i0, i1 in T.grid(2, 2):
        for ax0, ax1, ax2, ax3 in T.grid(6, 6, 3, 3):
            with T.block("B"):
                T.where((i0 < 1 or 2 <= ax0) and (i1 < 1 or 2 <= ax1))
                ax0_1 = T.axis.spatial(10, i0 * 4 + ax0)
                ax1_1 = T.axis.spatial(10, i1 * 4 + ax1)
                rv0, rv1 = T.axis.remap("RR", [ax2, ax3])
                B[ax0_1 % 6, ax1_1] = T.max(B[ax0_1 % 6, ax1_1], A[ax0_1 + rv0, ax1_1 + rv1])
        for ax0, ax1, ax2, ax3 in T.grid(4, 4, 3, 3):
            with T.block("C"):
                ax0_1 = T.axis.spatial(8, i0 * 4 + ax0)
                ax1_1 = T.axis.spatial(8, i1 * 4 + ax1)
                rv0, rv1 = T.axis.remap("RR", [ax2, ax3])
                C[ax0_1, ax1_1] = T.max(C[ax0_1, ax1_1], B[ax0_1 % 6 + rv0, ax1_1 + rv1])

Difference from TE rolling_buffer

TIR rolling_buffer will only select a dimension with a positive bound overlap as rolling dimension, consider the following example, the collected bound overlap for buffer B is [0, 0, 2, 0].

@T.prim_func
def before(
    A: T.Buffer[(1, 12, 14, 16), "int8"], C: T.Buffer[(1, 12, 12, 16), "int8"]
):
    B = T.alloc_buffer([1, 12, 14, 16], dtype="int8")
    for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 3, 3, 1):
        for ax0, ax1, ax2 in T.grid(4, 6, 16):
            with T.block("B"):
                ax0_1 = T.axis.spatial(1, 0)
                ax1_1 = T.axis.spatial(12, i1_0 * 4 + ax0)
                ax2_1 = T.axis.spatial(12, i2_0 * 4 + ax1)
                ax3 = T.axis.spatial(16, ax2)
                B[ax0_1, ax1_1, ax2_1, ax3] = A[ax0_1, ax1_1, ax2_1, ax3]
        for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 4, 4, 16, 1, 3):
            with T.block("C"):
                ax0 = T.axis.spatial(1, i0_0 + i0_1)
                ax1 = T.axis.spatial(12, i1_0 * 4 + i1_1)
                ax2 = T.axis.spatial(12, i2_0 * 4 + i2_1)
                ax3 = T.axis.spatial(16, i3_0 * 16 + i3_1)
                rv0, rv1 = T.axis.remap("RR", [i4, i5])
                C[ax0, ax1, ax2, ax3] = T.max(C[ax0, ax1, ax2, ax3], B[ax0, ax1 + rv0, ax2 + rv1, ax3])

For the logic of TE rolling_buffer, i1_0 will be selected as the rolling axis and its range will be folded to [0, 4] to compact and minimize the buffer size. But for TensorIR, buffer region compaction will be performed by CompactBufferAllocation pass, so the primitive will select i2_0 with a positive bound overlap to be the rolling axis to circularize the buffer.

@T.prim_func
def after(A: T.Buffer[(1, 12, 14, 16), "int8"], C: T.Buffer[(1, 12, 12, 16), "int8"]) -> None:
    # body
    # with T.block("root")
    B = T.alloc_buffer([1, 12, 6, 16], dtype="int8")
    for i0_0, i1_0, i2_0, i3_0 in T.grid(1, 3, 3, 1):
        for ax0, ax1, ax2 in T.grid(4, 6, 16):
            with T.block("B"):
                T.where(i2_0 < 1 or 2 <= ax1)
                ax0_1 = T.axis.spatial(1, 0)
                ax1_1 = T.axis.spatial(12, i1_0 * 4 + ax0)
                ax2_1 = T.axis.opaque(12, i2_0 * 4 + ax1)
                ax3 = T.axis.spatial(16, ax2)
                B[ax0_1, ax1_1, ax2_1 % 6, ax3] = T.max(B[ax0_1, ax1_1, ax2_1 % 6, ax3], A[ax0_1, ax1_1, ax2_1, ax3])
        for i0_1, i1_1, i2_1, i3_1, i4, i5 in T.grid(1, 4, 4, 16, 1, 3):
            with T.block("C"):
                ax0 = T.axis.spatial(1, i0_0 + i0_1)
                ax1 = T.axis.spatial(12, i1_0 * 4 + i1_1)
                ax2 = T.axis.opaque(12, i2_0 * 4 + i2_1)
                ax3 = T.axis.spatial(16, i3_0 * 16 + i3_1)
                rv0, rv1 = T.axis.remap("RR", [i4, i5])
                C[ax0, ax1, ax2, ax3] = T.max(C[ax0, ax1, ax2, ax3], B[ax0, ax1 + rv0, (ax2 + rv1) % 6, ax3])

Note that if the acess region of the target buffer does not have a positive bound overlap in any dimension, the primitive would fail and throw an error, please let me know if this is inappropriate.
cc @wrongtest-intellif

@liangW-intellif liangW-intellif force-pushed the tir_rolling_buffer branch 3 times, most recently from e0c68de to 1568fbf Compare October 12, 2022 02:22
@tvm-bot
Copy link
Collaborator

tvm-bot commented Oct 12, 2022

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@wrongtest-intellif
Copy link
Contributor

cc @mbaret @junrushao

@liangW-intellif liangW-intellif force-pushed the tir_rolling_buffer branch 2 times, most recently from a76af4e to 8f60dae Compare October 18, 2022 09:08
@areusch areusch added needs-triage PRs or issues that need to be investigated by maintainers to find the right assignees to address it and removed needs-triage PRs or issues that need to be investigated by maintainers to find the right assignees to address it labels Oct 19, 2022
@Hzfengsy
Copy link
Member

cc @mbaret

Copy link
Contributor

@wrongtest-intellif wrongtest-intellif left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks for your efforts!

@wrongtest-intellif wrongtest-intellif merged commit 2c1fecd into apache:main Nov 1, 2022
nverke pushed a commit to nverke/tvm that referenced this pull request Nov 3, 2022
apache#13033)

* [TIR][Primitive] Support rolling_buffer schedule primitive in TensorIR

* Address review comments

* Add dependency checks
xinetzone pushed a commit to daobook/tvm that referenced this pull request Nov 10, 2022
apache#13033)

* [TIR][Primitive] Support rolling_buffer schedule primitive in TensorIR

* Address review comments

* Add dependency checks
xinetzone pushed a commit to daobook/tvm that referenced this pull request Nov 25, 2022
apache#13033)

* [TIR][Primitive] Support rolling_buffer schedule primitive in TensorIR

* Address review comments

* Add dependency checks
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants