Skip to content

Conversation

@nverke
Copy link
Contributor

@nverke nverke commented Nov 5, 2022

…odes match_buffer statements when validating writes

Previously this check did not take into account any match_buffer statements and consequently would fail for tensorized schedules. Now it takes these into account when possible.

cc @vinx13 @Hzfengsy

@tvm-bot
Copy link
Collaborator

tvm-bot commented Nov 5, 2022

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

@github-actions github-actions bot requested a review from vinx13 November 5, 2022 00:15
@Hzfengsy
Copy link
Member

Hzfengsy commented Nov 5, 2022

Thanks @nverke for the enhancement. Could you please add a regression test case?

@nverke nverke force-pushed the match_buffer_check branch from b15818e to a85aecd Compare November 7, 2022 19:28
@nverke
Copy link
Contributor Author

nverke commented Nov 7, 2022

Thanks @nverke for the enhancement. Could you please add a regression test case?

Just added!

@github-actions github-actions bot requested a review from Hzfengsy November 7, 2022 19:32
@nverke
Copy link
Contributor Author

nverke commented Nov 7, 2022

@tvm-bot rerun

ICHECK(buffer_written.count(store->buffer.get()))
<< "ValueError: The buffer \"" << store->buffer
<< "\" is written in the block but is not in the block's signature";
const auto* body_block = block->body.as<BlockRealizeNode>();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

checking block->body is not sufficient. It is possible the inner block has outer loops.
consider the case:

block1
  for …
     block2
        match_buffer
        buffer_store

In this case, we need to the parent block of the store statement

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I am not sure if I understand. Are you talking about a situation like this?

@T.prim_func
def nested_reduction_loop_with_match_buffers(
    in0: T.Buffer[(4, 4, 4), "int8"],
    in1: T.Buffer[(4, 4, 4), "int8"],
    out: T.Buffer[(4, 4, 4), "int8"],
) -> None:
    # body
    # with T.block("root")
    for y in T.serial(4):
        with T.block("C"):
            T.reads(in0[y, 0:4, 0:4], in1[y, 0:4, 0:4])
            T.writes(out[y, 0:4, 0:4])
            for x in T.serial(4):
                with T.block("C"):
                    T.reads(in0[y, x, 0:4], in1[y, x, 0:4])
                    T.writes(out[y, x, 0:4])
                    A = T.match_buffer(in0[y, x, 0:4], [4], dtype="int8", offset_factor=1)
                    B = T.match_buffer(in1[y, x, 0:4], [4], dtype="int8", offset_factor=1)
                    C = T.match_buffer(out[y, x, 0:4], [4], dtype="int8", offset_factor=1)
                    A_i8x4: T.int8x4 = A[0:4]
                    A_i32: T.int32 = T.reinterpret(A_i8x4, dtype="int32")
                    B_i8x4: T.int8x4 = B[0:4]
                    B_i32: T.int32 = T.reinterpret(B_i8x4, dtype="int32")
                    C[0:4] = T.reinterpret(A_i32 + B_i32, dtype="int8x4")

My understanding is that this check is on the level of the store statement and already has collected all of the write regions for the parent loops so just adding the regions that the match buffers check should be enough.

Alternatively are you referring to something like this?

@T.prim_func
def nested_reduction_loop_with_match_buffers(
    in0: T.Buffer[(4, 4, 4), "int8"],
    in1: T.Buffer[(4, 4, 4), "int8"],
    out: T.Buffer[(4, 4, 4), "int8"],
) -> None:
    # body
    # with T.block("root")
    for y in T.serial(4):
        with T.block("C"):
            T.reads(in0[y, 0:4, 0:4], in1[y, 0:4, 0:4])
            T.writes(out[y, 0:4, 0:4])
            A = T.match_buffer(in0[y, 0:4, 0:4], [4, 4], dtype="int8", offset_factor=1)
            B = T.match_buffer(in1[y, 0:4, 0:4], [4, 4], dtype="int8", offset_factor=1)
            C = T.match_buffer(out[y, 0:4, 0:4], [4, 4], dtype="int8", offset_factor=1)
            for x in T.serial(4):
                with T.block("C"):
                    T.reads(in0[y, x, 0:4], in1[y, x, 0:4])
                    T.writes(out[y, x, 0:4])
                    A_i8x4: T.int8x4 = A[x, 0:4]
                    A_i32: T.int32 = T.reinterpret(A_i8x4, dtype="int32")
                    B_i8x4: T.int8x4 = B[x, 0:4]
                    B_i32: T.int32 = T.reinterpret(B_i8x4, dtype="int32")
                    C[x, 0:4] = T.reinterpret(A_i32 + B_i32, dtype="int8x4")

Here I believe we are still able to pickup the match buffers from the body block.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm referring to your former example. In this case, if (body_block) will always be false, and match_buffer will not be checked.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alright made changes accordingly and added a test to validate this case as well as another test that on the other scenario even though that was never an issue.

@tmoreau89 tmoreau89 merged commit 7cd203d into apache:main Nov 10, 2022
@tmoreau89
Copy link
Contributor

Thank you @nverke @vinx13 - PR has been merged!

Comment on lines +611 to +619
xr = T.axis.reduce(4, x)
with T.init():
for i in T.serial(4):
with T.block("C_init"):
ii = T.axis.spatial(4, i)
T.reads()
T.writes(out[yi, ii])
out[yi, ii] = 0
with T.block("C"):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm late to the party! Just a quick note: this TVMScript is not valid TIR and I happened to detect it when using the new TVMScript parser which checks more carefully :-)

More specifically, both T.init() and T.axis.reduce should be placed immediately under a TIR block, while line 611 and line 612 are not :-(

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we print out the testcase using nested_reduction_loop_with_outer_match_buffers.show(), then the TIR looks like:

# from tvm.script import tir as T
@T.prim_func
def func(in0: T.Buffer[(4, 16), "int8"], in1: T.Buffer[(4, 16), "int8"], out: T.Buffer[(4, 4), "int32"]):
    # body
    # with T.block("root")
    for y in T.serial(4):
        with T.block("C"):
            yi = T.axis.spatial(4, y)
            xr = T.axis.reduce(4, x)
            T.reads(in0[yi, 0 : 16], in1[yi, 0 : 16])
            T.writes(out[yi, 0 : 4])
            A = T.match_buffer(in0[yi, 0 : 16], [16], dtype="int8", offset_factor=1)
            B = T.match_buffer(in1[yi, 0 : 16], [16], dtype="int8", offset_factor=1)
            C = T.match_buffer(out[yi, 0 : 4], [4], dtype="int32", offset_factor=1)
            with T.init():
                for i in T.serial(4):
                    with T.block("C_init"):
                        ii = T.axis.spatial(4, i)
                        T.reads()
                        T.writes(out[yi, ii])
                        out[yi, ii] = 0
            for x in T.serial(4):
                with T.block("C"):
                    T.reads(out[yi, xr], in0[yi, yi * 4 + xr : yi * 4 + xr + 4], in1[yi, yi * 4 + xr : yi * 4 + xr + 4])
                    T.writes(out[yi, xr])
                    A_i8x4: T.int8x4 = A[yi * 4 + xr:yi * 4 + xr + 4]
                    A_i32: T.int32 = T.reinterpret(A_i8x4, dtype="int32")
                    B_i8x4: T.int8x4 = B[yi * 4 + xr:yi * 4 + xr + 4]
                    B_i32: T.int32 = T.reinterpret(B_i8x4, dtype="int32")
                    C[xr] = A_i32 + B_i32 + C[xr]

where we may see some use-before-def issues

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i would love to temporarily exclude this particular TVMScript from testing, but happy to merge it back if you have a follow-up PR to fix :-)

Copy link
Contributor Author

@nverke nverke Nov 11, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh interesting! I see so the reduction axis is tied to the block despite being within the loop. I can follow up with a commit in a few days or you can remove the test if its causing issues.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

xinetzone pushed a commit to daobook/tvm that referenced this pull request Nov 25, 2022
apache#13301)

* [TIR] Update ReductionIterNotIndexOutputBuffer to check BlockRealizeNodes match_buffer statements when validating writes

* Add test to verify that tensorized blocks are properly validated

* update to take into account all match buffer regions.

* lint
@nverke nverke deleted the match_buffer_check branch January 13, 2023 23:41
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

6 participants