Skip to content

Conversation

@masahi
Copy link
Member

@masahi masahi commented Nov 22, 2022

After the PR #13327, I'm getting a dtype-mismatch error at

IterVar(Range::FromMinExtent(make_zero(dim.dtype()), dim), virtual_var, kDataPar));
, because dim.dtype() is now int64 while virtual_var.dtype() is int32. The dtypes of initial_indices in IndexMap are fixed to int32 (see below), this is in conflict with the above PR which made the dtypes of buffer shapes int64.

initial_indices.push_back(Var("i" + std::to_string(i), DataType::Int(32)));

default_index_dtype = "int32"

Since initial_indices is used everywhere to construct loop_var / iter_var / iter_values of For and Block etc, I'm adding a dtype legalization at the beginning of TransformLayout, when the dtypes of the input buffer and initial_indices do not match.

@vinx13 @Lunderberg

@tvm-bot
Copy link
Collaborator

tvm-bot commented Nov 22, 2022

Thanks for contributing to TVM! Please refer to the contributing guidelines https://tvm.apache.org/docs/contribute/ for useful information and tips. Please request code reviews from Reviewers by @-ing them in a comment.

Generated by tvm-bot

Copy link
Contributor

@Lunderberg Lunderberg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, and the change looks reasonable. The MapIndices method already uses SubstituteWithDataTypeLegalization, so it makes sense for the transform layout to use it as well.

@Lunderberg
Copy link
Contributor

Regarding the readability of the unit test, I did some poking around, and the following PrimFunc triggers the same error on main when the layout is transformed using lambda h: [h//8, h%8].

@T.prim_func
def func(A: T.Buffer[T.int64(58), "int32"]):
    for i in T.serial(T.int64(58)):
        with T.block("block"):
            vi = T.axis.remap("S", [i])
            A[vi] = 0

@masahi
Copy link
Member Author

masahi commented Nov 22, 2022

@Lunderberg Thanks for finding the simple test case! I hit this error when running this Hexagon test

def test_packed_8x8x32_resnet50(hexagon_launcher):
. The original test case in this PR, with its cache_read, index map etc, was directly taken from this Hexagon test.

The bug is triggered when we hit the code path

// If we got to this point, all indices used to access the

I was not sure what makes this code path hit and why existing tests didn't hit it.

@masahi masahi force-pushed the transform-layout-dtype-fix branch from 0bc7436 to aa7f08e Compare November 22, 2022 21:10
@masahi
Copy link
Member Author

masahi commented Nov 28, 2022

@Lunderberg @vinx13 PTAL, thanks.

Copy link
Contributor

@Lunderberg Lunderberg left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for making the changes, and LGTM!

@vinx13 vinx13 merged commit 36d18e9 into apache:main Nov 28, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants