-
Notifications
You must be signed in to change notification settings - Fork 3.8k
[TIR] Fix buffer shape and IndexMap indices dtype mismatch #13463
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.
Generated by tvm-bot |
Lunderberg
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch, and the change looks reasonable. The MapIndices method already uses SubstituteWithDataTypeLegalization, so it makes sense for the transform layout to use it as well.
|
Regarding the readability of the unit test, I did some poking around, and the following PrimFunc triggers the same error on main when the layout is transformed using @T.prim_func
def func(A: T.Buffer[T.int64(58), "int32"]):
for i in T.serial(T.int64(58)):
with T.block("block"):
vi = T.axis.remap("S", [i])
A[vi] = 0 |
|
@Lunderberg Thanks for finding the simple test case! I hit this error when running this Hexagon test
cache_read, index map etc, was directly taken from this Hexagon test.
The bug is triggered when we hit the code path
I was not sure what makes this code path hit and why existing tests didn't hit it. |
0bc7436 to
aa7f08e
Compare
|
@Lunderberg @vinx13 PTAL, thanks. |
Lunderberg
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for making the changes, and LGTM!
After the PR #13327, I'm getting a dtype-mismatch error at
tvm/src/tir/schedule/primitive/layout_transformation.cc
Line 310 in b6fae9b
dim.dtype()is now int64 whilevirtual_var.dtype()is int32. The dtypes ofinitial_indicesinIndexMapare fixed to int32 (see below), this is in conflict with the above PR which made the dtypes of buffer shapes int64.tvm/src/tir/ir/index_map.cc
Line 51 in 458ca81
tvm/python/tvm/tir/function.py
Line 395 in 78b5322
Since
initial_indicesis used everywhere to constructloop_var/iter_var/iter_valuesofForandBlocketc, I'm adding a dtype legalization at the beginning ofTransformLayout, when the dtypes of the input buffer andinitial_indicesdo not match.@vinx13 @Lunderberg