[TIR] Update ReductionIterNotIndexOutputBuffer to check BlockRealizeN… #13301
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot
Thanks @nverke for the enhancement. Could you please add a regression test case?
…odes match_buffer statements when validating writes
b15818e to a85aecd
Just added!
@tvm-bot rerun
src/tir/schedule/analysis/reducer.cc
Outdated
ICHECK(buffer_written.count(store->buffer.get()))
    << "ValueError: The buffer \"" << store->buffer
    << "\" is written in the block but is not in the block's signature";
const auto* body_block = block->body.as<BlockRealizeNode>();
Checking block->body is not sufficient. It is possible that the inner block has outer loops.
Consider the case:
block1
    for …
        block2
            match_buffer
            buffer_store
In this case, we need to look at the parent block of the store statement.
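The case above can be illustrated with a toy model. This is a minimal sketch in plain Python, not TVM's API: `Store`, `For`, `Block`, `direct_child_block`, and `stores_with_parent` are hypothetical names made up for illustration. It shows that looking only at the outer block's immediate body misses an inner block that sits under a loop, while tracking the nearest enclosing block during the walk finds it.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Store:
    buffer: str  # name of the buffer being written


@dataclass
class For:
    body: list  # child nodes inside the loop


@dataclass
class Block:
    name: str
    match_buffers: List[str] = field(default_factory=list)  # aliases declared here
    body: list = field(default_factory=list)


def direct_child_block(block):
    """Naive lookup: only succeeds if the body is itself a block."""
    if len(block.body) == 1 and isinstance(block.body[0], Block):
        return block.body[0]
    return None


def stores_with_parent(node, parent=None):
    """Pair every store with the name of its nearest enclosing block."""
    if isinstance(node, Store):
        return [(node.buffer, parent.name if parent else None)]
    # Loops are transparent; only a Block changes the enclosing block.
    next_parent = node if isinstance(node, Block) else parent
    found = []
    for child in node.body:
        found.extend(stores_with_parent(child, next_parent))
    return found


# block1 -> for -> block2 (match_buffer C) -> store to C
inner = Block("block2", match_buffers=["C"], body=[Store("C")])
outer = Block("block1", body=[For(body=[inner])])

print(direct_child_block(outer))   # None: the body is a For, not a Block
print(stores_with_parent(outer))   # [('C', 'block2')]
```

In this toy model the store is attributed to `block2`, which is the block whose `match_buffer` declarations would need to be consulted.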
Hmm I am not sure if I understand. Are you talking about a situation like this?
@T.prim_func
def nested_reduction_loop_with_match_buffers(
    in0: T.Buffer[(4, 4, 4), "int8"],
    in1: T.Buffer[(4, 4, 4), "int8"],
    out: T.Buffer[(4, 4, 4), "int8"],
) -> None:
    # body
    # with T.block("root")
    for y in T.serial(4):
        with T.block("C"):
            T.reads(in0[y, 0:4, 0:4], in1[y, 0:4, 0:4])
            T.writes(out[y, 0:4, 0:4])
            for x in T.serial(4):
                with T.block("C"):
                    T.reads(in0[y, x, 0:4], in1[y, x, 0:4])
                    T.writes(out[y, x, 0:4])
                    A = T.match_buffer(in0[y, x, 0:4], [4], dtype="int8", offset_factor=1)
                    B = T.match_buffer(in1[y, x, 0:4], [4], dtype="int8", offset_factor=1)
                    C = T.match_buffer(out[y, x, 0:4], [4], dtype="int8", offset_factor=1)
                    A_i8x4: T.int8x4 = A[0:4]
                    A_i32: T.int32 = T.reinterpret(A_i8x4, dtype="int32")
                    B_i8x4: T.int8x4 = B[0:4]
                    B_i32: T.int32 = T.reinterpret(B_i8x4, dtype="int32")
                    C[0:4] = T.reinterpret(A_i32 + B_i32, dtype="int8x4")
My understanding is that this check operates at the level of the store statement and has already collected all of the write regions for the parent loops, so adding the regions that the match buffers cover should be enough.
Alternatively are you referring to something like this?
@T.prim_func
def nested_reduction_loop_with_match_buffers(
    in0: T.Buffer[(4, 4, 4), "int8"],
    in1: T.Buffer[(4, 4, 4), "int8"],
    out: T.Buffer[(4, 4, 4), "int8"],
) -> None:
    # body
    # with T.block("root")
    for y in T.serial(4):
        with T.block("C"):
            T.reads(in0[y, 0:4, 0:4], in1[y, 0:4, 0:4])
            T.writes(out[y, 0:4, 0:4])
            A = T.match_buffer(in0[y, 0:4, 0:4], [4, 4], dtype="int8", offset_factor=1)
            B = T.match_buffer(in1[y, 0:4, 0:4], [4, 4], dtype="int8", offset_factor=1)
            C = T.match_buffer(out[y, 0:4, 0:4], [4, 4], dtype="int8", offset_factor=1)
            for x in T.serial(4):
                with T.block("C"):
                    T.reads(in0[y, x, 0:4], in1[y, x, 0:4])
                    T.writes(out[y, x, 0:4])
                    A_i8x4: T.int8x4 = A[x, 0:4]
                    A_i32: T.int32 = T.reinterpret(A_i8x4, dtype="int32")
                    B_i8x4: T.int8x4 = B[x, 0:4]
                    B_i32: T.int32 = T.reinterpret(B_i8x4, dtype="int32")
                    C[x, 0:4] = T.reinterpret(A_i32 + B_i32, dtype="int8x4")
Here I believe we are still able to pick up the match buffers from the body block.
I'm referring to your former example. In this case, if (body_block) will always be false, and match_buffer will not be checked.
Alright, I made changes accordingly and added a test to validate this case, as well as another test for the other scenario, even though that was never an issue.
xr = T.axis.reduce(4, x)
with T.init():
    for i in T.serial(4):
        with T.block("C_init"):
            ii = T.axis.spatial(4, i)
            T.reads()
            T.writes(out[yi, ii])
            out[yi, ii] = 0
with T.block("C"):
I'm late to the party! Just a quick note: this TVMScript is not valid TIR and I happened to detect it when using the new TVMScript parser which checks more carefully :-)
More specifically, both T.init() and T.axis.reduce should be placed immediately under a TIR block, while line 611 and line 612 are not :-(
If we print out the testcase using nested_reduction_loop_with_outer_match_buffers.show(), then the TIR looks like:
# from tvm.script import tir as T
@T.prim_func
def func(in0: T.Buffer[(4, 16), "int8"], in1: T.Buffer[(4, 16), "int8"], out: T.Buffer[(4, 4), "int32"]):
    # body
    # with T.block("root")
    for y in T.serial(4):
        with T.block("C"):
            yi = T.axis.spatial(4, y)
            xr = T.axis.reduce(4, x)
            T.reads(in0[yi, 0 : 16], in1[yi, 0 : 16])
            T.writes(out[yi, 0 : 4])
            A = T.match_buffer(in0[yi, 0 : 16], [16], dtype="int8", offset_factor=1)
            B = T.match_buffer(in1[yi, 0 : 16], [16], dtype="int8", offset_factor=1)
            C = T.match_buffer(out[yi, 0 : 4], [4], dtype="int32", offset_factor=1)
            with T.init():
                for i in T.serial(4):
                    with T.block("C_init"):
                        ii = T.axis.spatial(4, i)
                        T.reads()
                        T.writes(out[yi, ii])
                        out[yi, ii] = 0
            for x in T.serial(4):
                with T.block("C"):
                    T.reads(out[yi, xr], in0[yi, yi * 4 + xr : yi * 4 + xr + 4], in1[yi, yi * 4 + xr : yi * 4 + xr + 4])
                    T.writes(out[yi, xr])
                    A_i8x4: T.int8x4 = A[yi * 4 + xr:yi * 4 + xr + 4]
                    A_i32: T.int32 = T.reinterpret(A_i8x4, dtype="int32")
                    B_i8x4: T.int8x4 = B[yi * 4 + xr:yi * 4 + xr + 4]
                    B_i32: T.int32 = T.reinterpret(B_i8x4, dtype="int32")
                    C[xr] = A_i32 + B_i32 + C[xr]
where we may see some use-before-def issues
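To make the use-before-def problem concrete, here is a minimal sketch in plain Python. It is not part of TVM or its parser: `use_before_def` and the `(target, uses)` encoding are hypothetical, and the model deliberately ignores scoping and only looks at the textual order of definitions in the printed TIR.

```python
def use_before_def(statements):
    """Return (target, name) pairs where `name` is used before it is defined.

    `statements` is an ordered list of (defined_name, [used_names]).
    """
    defined = set()
    problems = []
    for target, uses in statements:
        for name in uses:
            if name not in defined:
                problems.append((target, name))
        defined.add(target)
    return problems


# Definition order as it appears in the printed TIR.
printed_order = [
    ("y", []),       # for y in T.serial(4)
    ("yi", ["y"]),   # yi = T.axis.spatial(4, y)
    ("xr", ["x"]),   # xr = T.axis.reduce(4, x)
    ("x", []),       # for x in T.serial(4)
]

print(use_before_def(printed_order))  # [('xr', 'x')]
```

The only flagged pair is `xr` reading `x`, which matches the printed output: the reduce axis is bound at the outer block, before the loop that defines `x`.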
I would love to temporarily exclude this particular TVMScript from testing, but I'm happy to merge it back if you have a follow-up PR to fix it :-)
Oh interesting! I see, so the reduction axis is tied to the block despite being within the loop. I can follow up with a commit in a few days, or you can remove the test if it's causing issues.
apache#13301)
* [TIR] Update ReductionIterNotIndexOutputBuffer to check BlockRealizeNodes match_buffer statements when validating writes
* Add test to verify that tensorized blocks are properly validated
* update to take into account all match buffer regions.
* lint
Previously this check did not take into account any match_buffer statements and consequently would fail for tensorized schedules. Now it takes these into account when possible.
cc @vinx13 @Hzfengsy
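The change described above can be sketched roughly as follows. This is a toy model in plain Python under stated assumptions, not the code in reducer.cc: `is_write_declared` is a hypothetical helper, buffers are represented by name only, and regions are ignored entirely. It treats a store as valid if its buffer appears in the block's write signature, or if the buffer is a `match_buffer` alias whose source appears there.

```python
def is_write_declared(store_buffer, writes, match_buffers):
    """Check a stored buffer against a block's declared writes.

    writes:        set of buffer names listed in the block's write signature
    match_buffers: dict mapping alias name -> source buffer name
    """
    if store_buffer in writes:
        return True
    # An alias is accepted only if the buffer it matches is declared as written.
    source = match_buffers.get(store_buffer)
    return source is not None and source in writes


writes = {"out"}

# Without match_buffer information, a store to the alias C is rejected.
print(is_write_declared("C", writes, {}))            # False

# With the alias recorded, the same store is accepted.
print(is_write_declared("C", writes, {"C": "out"}))  # True
```

A store to a buffer that is neither declared nor aliased to a declared buffer is still rejected in this sketch, which mirrors the intent of the original error message.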