[TIR] Change the behavior of read/write region analysis for reduction blocks. #10638
Conversation
We failed to address the case of outer product:

    @T.prim_func
    def outer_product_desc(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (16, 1), offset_factor=1)
        B = T.match_buffer(b, (16, 1), offset_factor=1)
        C = T.match_buffer(c, (16, 16), offset_factor=1)
        with T.block("root"):
            T.reads(
                A[0:16, 0:1],
                B[0:16, 0:1],
            )
            T.writes(C[0:16, 0:16])
            for i, j in T.grid(16, 16):
                with T.block("update"):
                    vii, vjj = T.axis.remap("SS", [i, j])
                    C[vii, vjj] = C[vii, vjj] + A[vii, 0] * B[vjj, 0]

Note that inside block "update", C is both read and written, yet the root block's T.reads does not include C.
@Hzfengsy @spectrometerHBH This PR is ready for review now.
The outer product issue is resolved: since the "update" block is a single step of a reduction rather than a full reduction block, we should not hide its write region from its read region.
Hzfengsy left a comment:
Please add a test case that looks like:

    for i, j, k in T.grid(...):
        with T.block():
            T.reads(C, A, B)
            T.writes(C)
            C[..] += A[..] * B[..]

For a reduction block (including both the init part and the update part), C should not appear in the read region. But for a standalone update part, C in the read region is necessary.
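To illustrate the second case, here is a minimal sketch of a standalone update block (buffer names, shapes, and dtypes are illustrative, not from the PR's test suite): since there is no T.init(), the block is a single update step rather than a full reduction block, so C must remain in its read region.

    from tvm.script import tir as T

    @T.prim_func
    def update_only(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (16, 16), "float32")
        B = T.match_buffer(b, (16, 16), "float32")
        C = T.match_buffer(c, (16, 16), "float32")
        for i, j, k in T.grid(16, 16, 16):
            with T.block("update"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                # no T.init() here: this is a plain update step, not a
                # full reduction block, so C is kept in the read region
                T.reads(C[vi, vj], A[vi, vk], B[vk, vj])
                T.writes(C[vi, vj])
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]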
@Hzfengsy thanks! I'll add the test.
    for (const BufferRegion& read_access : block->reads) {
      new_reads.push_back(read_access);
    }
    (const_cast<BlockNode*>(block))->reads = std::move(new_reads);
After this PR, I'm getting a structural hash mismatch in meta schedule ApplyHistoryBest, at tvm/src/meta_schedule/integration.cc, line 134 in ce335c3:

    if (database->HasWorkload(prim_mod)) {
I found that if I do decompose_reduction, IRModules that have been committed to the database are modified. Specifically, the query module has

    reads([placeholder[b, i, k], T_layout_trans[b, floordiv(j, 16), floordiv(k, 4), floormod(j, 16), floormod(k, 4)]])
    writes([compute[b, i, j]])
    ...
    with init() {
      compute[b, i, j] = 0
    }
    compute[b, i, j] = (compute[b, i, j] + (int32(placeholder[b, i, k])*int32(T_layout_trans[b, floordiv(j, 16), floordiv(k, 4), floormod(j, 16), floormod(k, 4)])))

but the corresponding mod in the database has different reads:

    reads(compute[b, i, j], [placeholder[b, i, k], T_layout_trans[b, floordiv(j, 16), floordiv(k, 4), floormod(j, 16), floormod(k, 4)]])
    ...
I haven't looked at what this PR does, but I'm pretty sure this line is the offending bug...
@masahi thanks for reporting this! I'm not sure why the IRModule is mutated in place after decompose_reduction (not sure if I understand correctly), but it would be great to have a minimal reproducible example after you return from vacation :-)
@yzh119 would you mind taking a look at the particular case Masa mentions and perhaps add a regression test?
yeah, I can prepare a repro. It's a bit complicated since it involves scheduling, database stuff, and ApplyHistoryBest. I know the issue comes from this PR since I bisected it.
So when decompose_reduction is called, we now execute the const_cast in-place update quoted above, and I'm guessing that this in-place update propagates all the way up to the committed prim_mod in the tuning database.
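Roughly, the kind of check that would expose such a mutation (a hypothetical sketch; `mod` and the block name "update" are placeholders, not taken from the actual repro):

    import tvm
    from tvm import tir

    # `mod` is assumed to be an IRModule containing a reduction block
    # named "update" under a reduction loop (placeholder names)
    h_before = tvm.ir.structural_hash(mod)

    sch = tir.Schedule(mod)
    block = sch.get_block("update")
    k = sch.get_loops(block)[-1]  # e.g. the innermost reduction loop
    sch.decompose_reduction(block, k)

    # schedule primitives are supposed to leave the input module
    # untouched; if the guess above is right, the const_cast mutates the
    # shared BlockNode's read regions in place and this assertion fails
    assert h_before == tvm.ir.structural_hash(mod)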
ok, the repro is ready: https://gist.github.com/masahi/591078723d26f09ece3430af95835c99
The test works with the current main. When I run it, I get this output:

    With decompose_reduction=True
    One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
    [07:20:13] /home/masa/projects/dev/tvm/src/meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_nn_contrib_dense_pack
    [07:20:13] /home/masa/projects/dev/tvm/src/meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_expand_dims
    [07:20:13] /home/masa/projects/dev/tvm/src/meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_cast
    [07:20:13] /home/masa/projects/dev/tvm/src/meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_cast_1
    [07:20:13] /home/masa/projects/dev/tvm/src/meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_nn_batch_matmul
    With decompose_reduction=False
    [07:20:13] /home/masa/projects/dev/tvm/src/meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_expand_dims
    [07:20:13] /home/masa/projects/dev/tvm/src/meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_cast
    [07:20:13] /home/masa/projects/dev/tvm/src/meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_cast_1
This means that if decompose_reduction is applied during scheduling, ApplyHistoryBest fails to match the structural hash of the query mod (corresponding to nn_contrib_dense and nn_batch_matmul) against the ones in the database.
If I revert this commit, I get the following expected output:

    With decompose_reduction=True
    One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
    [07:24:21] /home/masa/projects/dev/tvm/src/meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_expand_dims
    [07:24:21] /home/masa/projects/dev/tvm/src/meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_cast
    [07:24:21] /home/masa/projects/dev/tvm/src/meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_cast_1
    With decompose_reduction=False
    [07:24:21] /home/masa/projects/dev/tvm/src/meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_expand_dims
    [07:24:21] /home/masa/projects/dev/tvm/src/meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_cast
    [07:24:21] /home/masa/projects/dev/tvm/src/meta_schedule/integration.cc:146: Warning: Cannot find workload: tvmgen_default_fused_cast_1
You can dump the contents of the database at tvm/src/meta_schedule/database/json_database.cc, lines 72 to 74 in b08e8e4:

    bool HasWorkload(const IRModule& mod) {
      return workloads2idx_.find(Workload(mod, tvm::StructuralHash()(mod))) != workloads2idx_.end();
    }
The reads region for dense and bmm is modified if decompose_reduction is enabled. This causes the structural hash mismatch between the modules in the database and the query mod.
After discussion with @spectrometerHBH @Hzfengsy, we decided to exclude a buffer access from the read regions if the buffer is written to inside a reduction block. This way, the outer block no longer sees the same region in both its reads and its writes, which solves the issue mentioned in #10420.
One tricky case is how to handle opaque memory access in GetBlockReadWriteRegion, where we have no hint about which buffer is being written to. There I keep the original behavior: an opaque access is added to both the read and write regions of a block, whether or not the block is a reduction block.
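For contrast with the standalone-update case above, here is a minimal sketch of the new behavior for a full reduction block (an illustrative matmul, not a test from this PR): because the block carries a T.init() and a reduction iterator, C is written inside a reduction block and is therefore excluded from T.reads, even though the update step also reads it.

    from tvm.script import tir as T

    @T.prim_func
    def matmul(a: T.handle, b: T.handle, c: T.handle) -> None:
        A = T.match_buffer(a, (16, 16), "float32")
        B = T.match_buffer(b, (16, 16), "float32")
        C = T.match_buffer(c, (16, 16), "float32")
        for i, j, k in T.grid(16, 16, 16):
            with T.block("update"):
                vi, vj, vk = T.axis.remap("SSR", [i, j, k])
                # C[vi, vj] is written inside this reduction block, so
                # it no longer appears in the read region
                T.reads(A[vi, vk], B[vk, vj])
                T.writes(C[vi, vj])
                with T.init():
                    C[vi, vj] = 0.0
                C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]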